From 1dac2263372df2b85db5d029a45721fa158a5c9d Mon Sep 17 00:00:00 2001 From: xiubuzhe Date: Sun, 8 Oct 2023 20:59:00 +0800 Subject: first add files --- lib/sqlalchemy/sql/__init__.py | 150 + lib/sqlalchemy/sql/annotation.py | 364 ++ lib/sqlalchemy/sql/base.py | 1702 ++++++++ lib/sqlalchemy/sql/coercions.py | 1096 +++++ lib/sqlalchemy/sql/compiler.py | 5525 ++++++++++++++++++++++++ lib/sqlalchemy/sql/crud.py | 1091 +++++ lib/sqlalchemy/sql/ddl.py | 1341 ++++++ lib/sqlalchemy/sql/default_comparator.py | 360 ++ lib/sqlalchemy/sql/dml.py | 1514 +++++++ lib/sqlalchemy/sql/elements.py | 5415 +++++++++++++++++++++++ lib/sqlalchemy/sql/events.py | 331 ++ lib/sqlalchemy/sql/expression.py | 278 ++ lib/sqlalchemy/sql/functions.py | 1575 +++++++ lib/sqlalchemy/sql/lambdas.py | 1314 ++++++ lib/sqlalchemy/sql/naming.py | 210 + lib/sqlalchemy/sql/operators.py | 1688 ++++++++ lib/sqlalchemy/sql/roles.py | 239 + lib/sqlalchemy/sql/schema.py | 5268 ++++++++++++++++++++++ lib/sqlalchemy/sql/selectable.py | 6946 ++++++++++++++++++++++++++++++ lib/sqlalchemy/sql/sqltypes.py | 3351 ++++++++++++++ lib/sqlalchemy/sql/traversals.py | 1559 +++++++ lib/sqlalchemy/sql/type_api.py | 1974 +++++++++ lib/sqlalchemy/sql/util.py | 1120 +++++ lib/sqlalchemy/sql/visitors.py | 852 ++++ 24 files changed, 45263 insertions(+) create mode 100644 lib/sqlalchemy/sql/__init__.py create mode 100644 lib/sqlalchemy/sql/annotation.py create mode 100644 lib/sqlalchemy/sql/base.py create mode 100644 lib/sqlalchemy/sql/coercions.py create mode 100644 lib/sqlalchemy/sql/compiler.py create mode 100644 lib/sqlalchemy/sql/crud.py create mode 100644 lib/sqlalchemy/sql/ddl.py create mode 100644 lib/sqlalchemy/sql/default_comparator.py create mode 100644 lib/sqlalchemy/sql/dml.py create mode 100644 lib/sqlalchemy/sql/elements.py create mode 100644 lib/sqlalchemy/sql/events.py create mode 100644 lib/sqlalchemy/sql/expression.py create mode 100644 lib/sqlalchemy/sql/functions.py create mode 100644 lib/sqlalchemy/sql/lambdas.py create mode 100644 lib/sqlalchemy/sql/naming.py create mode 100644 lib/sqlalchemy/sql/operators.py create mode 100644 lib/sqlalchemy/sql/roles.py create mode 100644 lib/sqlalchemy/sql/schema.py create mode 100644 lib/sqlalchemy/sql/selectable.py create mode 100644 lib/sqlalchemy/sql/sqltypes.py create mode 100644 lib/sqlalchemy/sql/traversals.py create mode 100644 lib/sqlalchemy/sql/type_api.py create mode 100644 lib/sqlalchemy/sql/util.py create mode 100644 lib/sqlalchemy/sql/visitors.py (limited to 'lib/sqlalchemy/sql') diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py new file mode 100644 index 0000000..2677441 --- /dev/null +++ b/lib/sqlalchemy/sql/__init__.py @@ -0,0 +1,150 @@ +# sql/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .base import Executable +from .compiler import COLLECT_CARTESIAN_PRODUCTS +from .compiler import FROM_LINTING +from .compiler import NO_LINTING +from .compiler import WARN_LINTING +from .expression import Alias +from .expression import alias +from .expression import all_ +from .expression import and_ +from .expression import any_ +from .expression import asc +from .expression import between +from .expression import bindparam +from .expression import case +from .expression import cast +from .expression import ClauseElement +from .expression import collate +from .expression import column +from .expression import ColumnCollection +from .expression import ColumnElement +from .expression import CompoundSelect +from .expression import cte +from .expression import Delete +from .expression import delete +from .expression import desc +from .expression import distinct +from .expression import except_ +from .expression import except_all +from .expression import exists +from .expression import extract +from .expression import false +from .expression import False_ +from .expression import FromClause +from .expression import func +from .expression import funcfilter +from .expression import Insert +from .expression import insert +from .expression import intersect +from .expression import intersect_all +from .expression import Join +from .expression import join +from .expression import label +from .expression import LABEL_STYLE_DEFAULT +from .expression import LABEL_STYLE_DISAMBIGUATE_ONLY +from .expression import LABEL_STYLE_NONE +from .expression import LABEL_STYLE_TABLENAME_PLUS_COL +from .expression import lambda_stmt +from .expression import LambdaElement +from .expression import lateral +from .expression import literal +from .expression import literal_column +from .expression import modifier +from .expression import not_ +from .expression import null +from .expression import nulls_first +from .expression import nulls_last +from .expression import nullsfirst +from .expression import nullslast +from .expression import or_ +from .expression import outerjoin +from .expression import outparam +from .expression import over +from .expression import quoted_name +from .expression import Select +from .expression import select +from .expression import Selectable +from .expression import StatementLambdaElement +from .expression import Subquery +from .expression import subquery +from .expression import table +from .expression import TableClause +from .expression import TableSample +from .expression import tablesample +from .expression import text +from .expression import true +from .expression import True_ +from .expression import tuple_ +from .expression import type_coerce +from .expression import union +from .expression import union_all +from .expression import Update +from .expression import update +from .expression import Values +from .expression import values +from .expression import within_group +from .visitors import ClauseVisitor + + +def __go(lcls): + global __all__ + from .. import util as _sa_util + + import inspect as _inspect + + __all__ = sorted( + name + for name, obj in lcls.items() + if not (name.startswith("_") or _inspect.ismodule(obj)) + ) + + from .annotation import _prepare_annotations + from .annotation import Annotated + from .elements import AnnotatedColumnElement + from .elements import ClauseList + from .selectable import AnnotatedFromClause + + # from .traversals import _preconfigure_traversals + + from . import base + from . import coercions + from . import elements + from . import events + from . import lambdas + from . import selectable + from . import schema + from . import sqltypes + from . import traversals + from . import type_api + + base.coercions = elements.coercions = coercions + base.elements = elements + base.type_api = type_api + coercions.elements = elements + coercions.lambdas = lambdas + coercions.schema = schema + coercions.selectable = selectable + coercions.sqltypes = sqltypes + coercions.traversals = traversals + + _prepare_annotations(ColumnElement, AnnotatedColumnElement) + _prepare_annotations(FromClause, AnnotatedFromClause) + _prepare_annotations(ClauseList, Annotated) + + # this is expensive at import time; elements that are used can create + # their traversals on demand + # _preconfigure_traversals(ClauseElement) + + _sa_util.preloaded.import_prefix("sqlalchemy.sql") + + from . import naming + + +__go(locals()) diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py new file mode 100644 index 0000000..5c000ed --- /dev/null +++ b/lib/sqlalchemy/sql/annotation.py @@ -0,0 +1,364 @@ +# sql/annotation.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""The :class:`.Annotated` class and related routines; creates hash-equivalent +copies of SQL constructs which contain context-specific markers and +associations. + +""" + +from . import operators +from .base import HasCacheKey +from .traversals import anon_map +from .visitors import InternalTraversal +from .. import util + +EMPTY_ANNOTATIONS = util.immutabledict() + + +class SupportsAnnotations(object): + _annotations = EMPTY_ANNOTATIONS + + @util.memoized_property + def _annotations_cache_key(self): + anon_map_ = anon_map() + return ( + "_annotations", + tuple( + ( + key, + value._gen_cache_key(anon_map_, []) + if isinstance(value, HasCacheKey) + else value, + ) + for key, value in [ + (key, self._annotations[key]) + for key in sorted(self._annotations) + ] + ), + ) + + +class SupportsCloneAnnotations(SupportsAnnotations): + + _clone_annotations_traverse_internals = [ + ("_annotations", InternalTraversal.dp_annotations_key) + ] + + def _annotate(self, values): + """return a copy of this ClauseElement with annotations + updated by the given dictionary. + + """ + new = self._clone() + new._annotations = new._annotations.union(values) + new.__dict__.pop("_annotations_cache_key", None) + new.__dict__.pop("_generate_cache_key", None) + return new + + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. + + """ + new = self._clone() + new._annotations = util.immutabledict(values) + new.__dict__.pop("_annotations_cache_key", None) + new.__dict__.pop("_generate_cache_key", None) + return new + + def _deannotate(self, values=None, clone=False): + """return a copy of this :class:`_expression.ClauseElement` + with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone or self._annotations: + # clone is used when we are also copying + # the expression for a deep deannotation + new = self._clone() + new._annotations = util.immutabledict() + new.__dict__.pop("_annotations_cache_key", None) + return new + else: + return self + + +class SupportsWrappingAnnotations(SupportsAnnotations): + def _annotate(self, values): + """return a copy of this ClauseElement with annotations + updated by the given dictionary. + + """ + return Annotated(self, values) + + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. + + """ + return Annotated(self, values) + + def _deannotate(self, values=None, clone=False): + """return a copy of this :class:`_expression.ClauseElement` + with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone: + s = self._clone() + return s + else: + return self + + +class Annotated(object): + """clones a SupportsAnnotated and applies an 'annotations' dictionary. + + Unlike regular clones, this clone also mimics __hash__() and + __cmp__() of the original element so that it takes its place + in hashed collections. + + A reference to the original element is maintained, for the important + reason of keeping its hash value current. When GC'ed, the + hash value may be reused, causing conflicts. + + .. note:: The rationale for Annotated producing a brand new class, + rather than placing the functionality directly within ClauseElement, + is **performance**. The __hash__() method is absent on plain + ClauseElement which leads to significantly reduced function call + overhead, as the use of sets and dictionaries against ClauseElement + objects is prevalent, but most are not "annotated". + + """ + + _is_column_operators = False + + def __new__(cls, *args): + if not args: + # clone constructor + return object.__new__(cls) + else: + element, values = args + # pull appropriate subclass from registry of annotated + # classes + try: + cls = annotated_classes[element.__class__] + except KeyError: + cls = _new_annotation_type(element.__class__, cls) + return object.__new__(cls) + + def __init__(self, element, values): + self.__dict__ = element.__dict__.copy() + self.__dict__.pop("_annotations_cache_key", None) + self.__dict__.pop("_generate_cache_key", None) + self.__element = element + self._annotations = util.immutabledict(values) + self._hash = hash(element) + + def _annotate(self, values): + _values = self._annotations.union(values) + return self._with_annotations(_values) + + def _with_annotations(self, values): + clone = self.__class__.__new__(self.__class__) + clone.__dict__ = self.__dict__.copy() + clone.__dict__.pop("_annotations_cache_key", None) + clone.__dict__.pop("_generate_cache_key", None) + clone._annotations = values + return clone + + def _deannotate(self, values=None, clone=True): + if values is None: + return self.__element + else: + return self._with_annotations( + util.immutabledict( + { + key: value + for key, value in self._annotations.items() + if key not in values + } + ) + ) + + def _compiler_dispatch(self, visitor, **kw): + return self.__element.__class__._compiler_dispatch(self, visitor, **kw) + + @property + def _constructor(self): + return self.__element._constructor + + def _clone(self, **kw): + clone = self.__element._clone(**kw) + if clone is self.__element: + # detect immutable, don't change anything + return self + else: + # update the clone with any changes that have occurred + # to this object's __dict__. + clone.__dict__.update(self.__dict__) + return self.__class__(clone, self._annotations) + + def __reduce__(self): + return self.__class__, (self.__element, self._annotations) + + def __hash__(self): + return self._hash + + def __eq__(self, other): + if self._is_column_operators: + return self.__element.__class__.__eq__(self, other) + else: + return hash(other) == hash(self) + + @property + def entity_namespace(self): + if "entity_namespace" in self._annotations: + return self._annotations["entity_namespace"].entity_namespace + else: + return self.__element.entity_namespace + + +# hard-generate Annotated subclasses. this technique +# is used instead of on-the-fly types (i.e. type.__new__()) +# so that the resulting objects are pickleable; additionally, other +# decisions can be made up front about the type of object being annotated +# just once per class rather than per-instance. +annotated_classes = {} + + +def _deep_annotate( + element, annotations, exclude=None, detect_subquery_cols=False +): + """Deep copy the given ClauseElement, annotating each element + with the given annotations dictionary. + + Elements within the exclude collection will be cloned but not annotated. + + """ + + # annotated objects hack the __hash__() method so if we want to + # uniquely process them we have to use id() + + cloned_ids = {} + + def clone(elem, **kw): + kw["detect_subquery_cols"] = detect_subquery_cols + id_ = id(elem) + + if id_ in cloned_ids: + return cloned_ids[id_] + + if ( + exclude + and hasattr(elem, "proxy_set") + and elem.proxy_set.intersection(exclude) + ): + newelem = elem._clone(clone=clone, **kw) + elif annotations != elem._annotations: + if detect_subquery_cols and elem._is_immutable: + newelem = elem._clone(clone=clone, **kw)._annotate(annotations) + else: + newelem = elem._annotate(annotations) + else: + newelem = elem + newelem._copy_internals(clone=clone) + cloned_ids[id_] = newelem + return newelem + + if element is not None: + element = clone(element) + clone = None # remove gc cycles + return element + + +def _deep_deannotate(element, values=None): + """Deep copy the given element, removing annotations.""" + + cloned = {} + + def clone(elem, **kw): + if values: + key = id(elem) + else: + key = elem + + if key not in cloned: + newelem = elem._deannotate(values=values, clone=True) + newelem._copy_internals(clone=clone) + cloned[key] = newelem + return newelem + else: + return cloned[key] + + if element is not None: + element = clone(element) + clone = None # remove gc cycles + return element + + +def _shallow_annotate(element, annotations): + """Annotate the given ClauseElement and copy its internals so that + internal objects refer to the new annotated object. + + Basically used to apply a "don't traverse" annotation to a + selectable, without digging throughout the whole + structure wasting time. + """ + element = element._annotate(annotations) + element._copy_internals() + return element + + +def _new_annotation_type(cls, base_cls): + if issubclass(cls, Annotated): + return cls + elif cls in annotated_classes: + return annotated_classes[cls] + + for super_ in cls.__mro__: + # check if an Annotated subclass more specific than + # the given base_cls is already registered, such + # as AnnotatedColumnElement. + if super_ in annotated_classes: + base_cls = annotated_classes[super_] + break + + annotated_classes[cls] = anno_cls = type( + "Annotated%s" % cls.__name__, (base_cls, cls), {} + ) + globals()["Annotated%s" % cls.__name__] = anno_cls + + if "_traverse_internals" in cls.__dict__: + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_key) + ] + elif cls.__dict__.get("inherit_cache", False): + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_key) + ] + + # some classes include this even if they have traverse_internals + # e.g. BindParameter, add it if present. + if cls.__dict__.get("inherit_cache", False): + anno_cls.inherit_cache = True + + anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators) + + return anno_cls + + +def _prepare_annotations(target_hierarchy, base_cls): + for cls in util.walk_subclasses(target_hierarchy): + _new_annotation_type(cls, base_cls) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py new file mode 100644 index 0000000..ec685d1 --- /dev/null +++ b/lib/sqlalchemy/sql/base.py @@ -0,0 +1,1702 @@ +# sql/base.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Foundational utilities common to many sql modules. + +""" + + +import itertools +import operator +import re + +from . import roles +from . import visitors +from .traversals import HasCacheKey # noqa +from .traversals import HasCopyInternals # noqa +from .traversals import MemoizedHasCacheKey # noqa +from .visitors import ClauseVisitor +from .visitors import ExtendedInternalTraversal +from .visitors import InternalTraversal +from .. import exc +from .. import util +from ..util import HasMemoized +from ..util import hybridmethod + + +coercions = None +elements = None +type_api = None + +PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT") +NO_ARG = util.symbol("NO_ARG") + + +class Immutable(object): + """mark a ClauseElement as 'immutable' when expressions are cloned.""" + + _is_immutable = True + + def unique_params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def _clone(self, **kw): + return self + + def _copy_internals(self, **kw): + pass + + +class SingletonConstant(Immutable): + """Represent SQL constants like NULL, TRUE, FALSE""" + + _is_singleton_constant = True + + def __new__(cls, *arg, **kw): + return cls._singleton + + @classmethod + def _create_singleton(cls): + obj = object.__new__(cls) + obj.__init__() + + # for a long time this was an empty frozenset, meaning + # a SingletonConstant would never be a "corresponding column" in + # a statement. This referred to #6259. However, in #7154 we see + # that we do in fact need "correspondence" to work when matching cols + # in result sets, so the non-correspondence was moved to a more + # specific level when we are actually adapting expressions for SQL + # render only. + obj.proxy_set = frozenset([obj]) + cls._singleton = obj + + +def _from_objects(*elements): + return itertools.chain.from_iterable( + [element._from_objects for element in elements] + ) + + +def _select_iterables(elements): + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain.from_iterable( + [c._select_iterable for c in elements] + ) + + +def _generative(fn): + """non-caching _generative() decorator. + + This is basically the legacy decorator that copies the object and + runs a method on the new copy. + + """ + + @util.decorator + def _generative(fn, self, *args, **kw): + """Mark a method as generative.""" + + self = self._generate() + x = fn(self, *args, **kw) + assert x is None, "generative methods must have no return value" + return self + + decorated = _generative(fn) + decorated.non_generative = fn + return decorated + + +def _exclusive_against(*names, **kw): + msgs = kw.pop("msgs", {}) + + defaults = kw.pop("defaults", {}) + + getters = [ + (name, operator.attrgetter(name), defaults.get(name, None)) + for name in names + ] + + @util.decorator + def check(fn, *args, **kw): + # make pylance happy by not including "self" in the argument + # list + self = args[0] + args = args[1:] + for name, getter, default_ in getters: + if getter(self) is not default_: + msg = msgs.get( + name, + "Method %s() has already been invoked on this %s construct" + % (fn.__name__, self.__class__), + ) + raise exc.InvalidRequestError(msg) + return fn(self, *args, **kw) + + return check + + +def _clone(element, **kw): + return element._clone(**kw) + + +def _expand_cloned(elements): + """expand the given set of ClauseElements to be the set of all 'cloned' + predecessors. + + """ + return itertools.chain(*[x._cloned_set for x in elements]) + + +def _cloned_intersection(a, b): + """return the intersection of sets a and b, counting + any overlap between 'cloned' predecessors. + + The returned set is in terms of the entities present within 'a'. + + """ + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set( + elem for elem in a if all_overlap.intersection(elem._cloned_set) + ) + + +def _cloned_difference(a, b): + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return set( + elem for elem in a if not all_overlap.intersection(elem._cloned_set) + ) + + +class _DialectArgView(util.collections_abc.MutableMapping): + """A dictionary view of dialect-level arguments in the form + _. + + """ + + def __init__(self, obj): + self.obj = obj + + def _key(self, key): + try: + dialect, value_key = key.split("_", 1) + except ValueError as err: + util.raise_(KeyError(key), replace_context=err) + else: + return dialect, value_key + + def __getitem__(self, key): + dialect, value_key = self._key(key) + + try: + opt = self.obj.dialect_options[dialect] + except exc.NoSuchModuleError as err: + util.raise_(KeyError(key), replace_context=err) + else: + return opt[value_key] + + def __setitem__(self, key, value): + try: + dialect, value_key = self._key(key) + except KeyError as err: + util.raise_( + exc.ArgumentError( + "Keys must be of the form _" + ), + replace_context=err, + ) + else: + self.obj.dialect_options[dialect][value_key] = value + + def __delitem__(self, key): + dialect, value_key = self._key(key) + del self.obj.dialect_options[dialect][value_key] + + def __len__(self): + return sum( + len(args._non_defaults) + for args in self.obj.dialect_options.values() + ) + + def __iter__(self): + return ( + "%s_%s" % (dialect_name, value_name) + for dialect_name in self.obj.dialect_options + for value_name in self.obj.dialect_options[ + dialect_name + ]._non_defaults + ) + + +class _DialectArgDict(util.collections_abc.MutableMapping): + """A dictionary view of dialect-level arguments for a specific + dialect. + + Maintains a separate collection of user-specified arguments + and dialect-specified default arguments. + + """ + + def __init__(self): + self._non_defaults = {} + self._defaults = {} + + def __len__(self): + return len(set(self._non_defaults).union(self._defaults)) + + def __iter__(self): + return iter(set(self._non_defaults).union(self._defaults)) + + def __getitem__(self, key): + if key in self._non_defaults: + return self._non_defaults[key] + else: + return self._defaults[key] + + def __setitem__(self, key, value): + self._non_defaults[key] = value + + def __delitem__(self, key): + del self._non_defaults[key] + + +@util.preload_module("sqlalchemy.dialects") +def _kw_reg_for_dialect(dialect_name): + dialect_cls = util.preloaded.dialects.registry.load(dialect_name) + if dialect_cls.construct_arguments is None: + return None + return dict(dialect_cls.construct_arguments) + + +class DialectKWArgs(object): + """Establish the ability for a class to have dialect-specific arguments + with defaults and constructor validation. + + The :class:`.DialectKWArgs` interacts with the + :attr:`.DefaultDialect.construct_arguments` present on a dialect. + + .. seealso:: + + :attr:`.DefaultDialect.construct_arguments` + + """ + + _dialect_kwargs_traverse_internals = [ + ("dialect_options", InternalTraversal.dp_dialect_options) + ] + + @classmethod + def argument_for(cls, dialect_name, argument_name, default): + """Add a new kind of dialect-specific keyword argument for this class. + + E.g.:: + + Index.argument_for("mydialect", "length", None) + + some_index = Index('a', 'b', mydialect_length=5) + + The :meth:`.DialectKWArgs.argument_for` method is a per-argument + way adding extra arguments to the + :attr:`.DefaultDialect.construct_arguments` dictionary. This + dictionary provides a list of argument names accepted by various + schema-level constructs on behalf of a dialect. + + New dialects should typically specify this dictionary all at once as a + data member of the dialect class. The use case for ad-hoc addition of + argument names is typically for end-user code that is also using + a custom compilation scheme which consumes the additional arguments. + + :param dialect_name: name of a dialect. The dialect must be + locatable, else a :class:`.NoSuchModuleError` is raised. The + dialect must also include an existing + :attr:`.DefaultDialect.construct_arguments` collection, indicating + that it participates in the keyword-argument validation and default + system, else :class:`.ArgumentError` is raised. If the dialect does + not include this collection, then any keyword argument can be + specified on behalf of this dialect already. All dialects packaged + within SQLAlchemy include this collection, however for third party + dialects, support may vary. + + :param argument_name: name of the parameter. + + :param default: default value of the parameter. + + .. versionadded:: 0.9.4 + + """ + + construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] + if construct_arg_dictionary is None: + raise exc.ArgumentError( + "Dialect '%s' does have keyword-argument " + "validation and defaults enabled configured" % dialect_name + ) + if cls not in construct_arg_dictionary: + construct_arg_dictionary[cls] = {} + construct_arg_dictionary[cls][argument_name] = default + + @util.memoized_property + def dialect_kwargs(self): + """A collection of keyword arguments specified as dialect-specific + options to this construct. + + The arguments are present here in their original ``_`` + format. Only arguments that were actually passed are included; + unlike the :attr:`.DialectKWArgs.dialect_options` collection, which + contains all options known by this dialect including defaults. + + The collection is also writable; keys are accepted of the + form ``_`` where the value will be assembled + into the list of options. + + .. versionadded:: 0.9.2 + + .. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs` + collection is now writable. + + .. seealso:: + + :attr:`.DialectKWArgs.dialect_options` - nested dictionary form + + """ + return _DialectArgView(self) + + @property + def kwargs(self): + """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`.""" + return self.dialect_kwargs + + _kw_registry = util.PopulateDict(_kw_reg_for_dialect) + + def _kw_reg_for_dialect_cls(self, dialect_name): + construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] + d = _DialectArgDict() + + if construct_arg_dictionary is None: + d._defaults.update({"*": None}) + else: + for cls in reversed(self.__class__.__mro__): + if cls in construct_arg_dictionary: + d._defaults.update(construct_arg_dictionary[cls]) + return d + + @util.memoized_property + def dialect_options(self): + """A collection of keyword arguments specified as dialect-specific + options to this construct. + + This is a two-level nested registry, keyed to ```` + and ````. For example, the ``postgresql_where`` + argument would be locatable as:: + + arg = my_object.dialect_options['postgresql']['where'] + + .. versionadded:: 0.9.2 + + .. seealso:: + + :attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form + + """ + + return util.PopulateDict( + util.portable_instancemethod(self._kw_reg_for_dialect_cls) + ) + + def _validate_dialect_kwargs(self, kwargs): + # validate remaining kwargs that they all specify DB prefixes + + if not kwargs: + return + + for k in kwargs: + m = re.match("^(.+?)_(.+)$", k) + if not m: + raise TypeError( + "Additional arguments should be " + "named _, got '%s'" % k + ) + dialect_name, arg_name = m.group(1, 2) + + try: + construct_arg_dictionary = self.dialect_options[dialect_name] + except exc.NoSuchModuleError: + util.warn( + "Can't validate argument %r; can't " + "locate any SQLAlchemy dialect named %r" + % (k, dialect_name) + ) + self.dialect_options[dialect_name] = d = _DialectArgDict() + d._defaults.update({"*": None}) + d._non_defaults[arg_name] = kwargs[k] + else: + if ( + "*" not in construct_arg_dictionary + and arg_name not in construct_arg_dictionary + ): + raise exc.ArgumentError( + "Argument %r is not accepted by " + "dialect %r on behalf of %r" + % (k, dialect_name, self.__class__) + ) + else: + construct_arg_dictionary[arg_name] = kwargs[k] + + +class CompileState(object): + """Produces additional object state necessary for a statement to be + compiled. + + the :class:`.CompileState` class is at the base of classes that assemble + state for a particular statement object that is then used by the + compiler. This process is essentially an extension of the process that + the SQLCompiler.visit_XYZ() method takes, however there is an emphasis + on converting raw user intent into more organized structures rather than + producing string output. The top-level :class:`.CompileState` for the + statement being executed is also accessible when the execution context + works with invoking the statement and collecting results. + + The production of :class:`.CompileState` is specific to the compiler, such + as within the :meth:`.SQLCompiler.visit_insert`, + :meth:`.SQLCompiler.visit_select` etc. methods. These methods are also + responsible for associating the :class:`.CompileState` with the + :class:`.SQLCompiler` itself, if the statement is the "toplevel" statement, + i.e. the outermost SQL statement that's actually being executed. + There can be other :class:`.CompileState` objects that are not the + toplevel, such as when a SELECT subquery or CTE-nested + INSERT/UPDATE/DELETE is generated. + + .. versionadded:: 1.4 + + """ + + __slots__ = ("statement",) + + plugins = {} + + @classmethod + def create_for_statement(cls, statement, compiler, **kw): + # factory construction. + + if statement._propagate_attrs: + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" + ) + klass = cls.plugins.get( + (plugin_name, statement._effective_plugin_target), None + ) + if klass is None: + klass = cls.plugins[ + ("default", statement._effective_plugin_target) + ] + + else: + klass = cls.plugins[ + ("default", statement._effective_plugin_target) + ] + + if klass is cls: + return cls(statement, compiler, **kw) + else: + return klass.create_for_statement(statement, compiler, **kw) + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + @classmethod + def get_plugin_class(cls, statement): + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", None + ) + + if plugin_name: + key = (plugin_name, statement._effective_plugin_target) + if key in cls.plugins: + return cls.plugins[key] + + # there's no case where we call upon get_plugin_class() and want + # to get None back, there should always be a default. return that + # if there was no plugin-specific class (e.g. Insert with "orm" + # plugin) + try: + return cls.plugins[("default", statement._effective_plugin_target)] + except KeyError: + return None + + @classmethod + def _get_plugin_class_for_plugin(cls, statement, plugin_name): + try: + return cls.plugins[ + (plugin_name, statement._effective_plugin_target) + ] + except KeyError: + return None + + @classmethod + def plugin_for(cls, plugin_name, visit_name): + def decorate(cls_to_decorate): + cls.plugins[(plugin_name, visit_name)] = cls_to_decorate + return cls_to_decorate + + return decorate + + +class Generative(HasMemoized): + """Provide a method-chaining pattern in conjunction with the + @_generative decorator.""" + + def _generate(self): + skip = self._memoized_keys + cls = self.__class__ + s = cls.__new__(cls) + if skip: + # ensure this iteration remains atomic + s.__dict__ = { + k: v for k, v in self.__dict__.copy().items() if k not in skip + } + else: + s.__dict__ = self.__dict__.copy() + return s + + +class InPlaceGenerative(HasMemoized): + """Provide a method-chaining pattern in conjunction with the + @_generative decorator that mutates in place.""" + + def _generate(self): + skip = self._memoized_keys + for k in skip: + self.__dict__.pop(k, None) + return self + + +class HasCompileState(Generative): + """A class that has a :class:`.CompileState` associated with it.""" + + _compile_state_plugin = None + + _attributes = util.immutabledict() + + _compile_state_factory = CompileState.create_for_statement + + +class _MetaOptions(type): + """metaclass for the Options class.""" + + def __init__(cls, classname, bases, dict_): + cls._cache_attrs = tuple( + sorted( + d + for d in dict_ + if not d.startswith("__") + and d not in ("_cache_key_traversal",) + ) + ) + type.__init__(cls, classname, bases, dict_) + + def __add__(self, other): + o1 = self() + + if set(other).difference(self._cache_attrs): + raise TypeError( + "dictionary contains attributes not covered by " + "Options class %s: %r" + % (self, set(other).difference(self._cache_attrs)) + ) + + o1.__dict__.update(other) + return o1 + + +class Options(util.with_metaclass(_MetaOptions)): + """A cacheable option dictionary with defaults.""" + + def __init__(self, **kw): + self.__dict__.update(kw) + + def __add__(self, other): + o1 = self.__class__.__new__(self.__class__) + o1.__dict__.update(self.__dict__) + + if set(other).difference(self._cache_attrs): + raise TypeError( + "dictionary contains attributes not covered by " + "Options class %s: %r" + % (self, set(other).difference(self._cache_attrs)) + ) + + o1.__dict__.update(other) + return o1 + + def __eq__(self, other): + # TODO: very inefficient. This is used only in test suites + # right now. + for a, b in util.zip_longest(self._cache_attrs, other._cache_attrs): + if getattr(self, a) != getattr(other, b): + return False + return True + + def __repr__(self): + # TODO: fairly inefficient, used only in debugging right now. + + return "%s(%s)" % ( + self.__class__.__name__, + ", ".join( + "%s=%r" % (k, self.__dict__[k]) + for k in self._cache_attrs + if k in self.__dict__ + ), + ) + + @classmethod + def isinstance(cls, klass): + return issubclass(cls, klass) + + @hybridmethod + def add_to_element(self, name, value): + return self + {name: getattr(self, name) + value} + + @hybridmethod + def _state_dict(self): + return self.__dict__ + + _state_dict_const = util.immutabledict() + + @_state_dict.classlevel + def _state_dict(cls): + return cls._state_dict_const + + @classmethod + def safe_merge(cls, other): + d = other._state_dict() + + # only support a merge with another object of our class + # and which does not have attrs that we don't. otherwise + # we risk having state that might not be part of our cache + # key strategy + + if ( + cls is not other.__class__ + and other._cache_attrs + and set(other._cache_attrs).difference(cls._cache_attrs) + ): + raise TypeError( + "other element %r is not empty, is not of type %s, " + "and contains attributes not covered here %r" + % ( + other, + cls, + set(other._cache_attrs).difference(cls._cache_attrs), + ) + ) + return cls + d + + @classmethod + def from_execution_options( + cls, key, attrs, exec_options, statement_exec_options + ): + """process Options argument in terms of execution options. + + + e.g.:: + + ( + load_options, + execution_options, + ) = QueryContext.default_load_options.from_execution_options( + "_sa_orm_load_options", + { + "populate_existing", + "autoflush", + "yield_per" + }, + execution_options, + statement._execution_options, + ) + + get back the Options and refresh "_sa_orm_load_options" in the + exec options dict w/ the Options as well + + """ + + # common case is that no options we are looking for are + # in either dictionary, so cancel for that first + check_argnames = attrs.intersection( + set(exec_options).union(statement_exec_options) + ) + + existing_options = exec_options.get(key, cls) + + if check_argnames: + result = {} + for argname in check_argnames: + local = "_" + argname + if argname in exec_options: + result[local] = exec_options[argname] + elif argname in statement_exec_options: + result[local] = statement_exec_options[argname] + + new_options = existing_options + result + exec_options = util.immutabledict().merge_with( + exec_options, {key: new_options} + ) + return new_options, exec_options + + else: + return existing_options, exec_options + + +class CacheableOptions(Options, HasCacheKey): + @hybridmethod + def _gen_cache_key(self, anon_map, bindparams): + return HasCacheKey._gen_cache_key(self, anon_map, bindparams) + + @_gen_cache_key.classlevel + def _gen_cache_key(cls, anon_map, bindparams): + return (cls, ()) + + @hybridmethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key_for_object(self) + + +class ExecutableOption(HasCopyInternals): + _annotations = util.EMPTY_DICT + + __visit_name__ = "executable_option" + + _is_has_cache_key = False + + def _clone(self, **kw): + """Create a shallow copy of this ExecutableOption.""" + c = self.__class__.__new__(self.__class__) + c.__dict__ = dict(self.__dict__) + return c + + +class Executable(roles.StatementRole, Generative): + """Mark a :class:`_expression.ClauseElement` as supporting execution. + + :class:`.Executable` is a superclass for all "statement" types + of objects, including :func:`select`, :func:`delete`, :func:`update`, + :func:`insert`, :func:`text`. + + """ + + supports_execution = True + _execution_options = util.immutabledict() + _bind = None + _with_options = () + _with_context_options = () + + _executable_traverse_internals = [ + ("_with_options", InternalTraversal.dp_executable_options), + ( + "_with_context_options", + ExtendedInternalTraversal.dp_with_context_options, + ), + ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs), + ] + + is_select = False + is_update = False + is_insert = False + is_text = False + is_delete = False + is_dml = False + + @property + def _effective_plugin_target(self): + return self.__visit_name__ + + @_generative + def options(self, *options): + """Apply options to this statement. + + In the general sense, options are any kind of Python object + that can be interpreted by the SQL compiler for the statement. + These options can be consumed by specific dialects or specific kinds + of compilers. + + The most commonly known kind of option are the ORM level options + that apply "eager load" and other loading behaviors to an ORM + query. However, options can theoretically be used for many other + purposes. + + For background on specific kinds of options for specific kinds of + statements, refer to the documentation for those option objects. + + .. versionchanged:: 1.4 - added :meth:`.Generative.options` to + Core statement objects towards the goal of allowing unified + Core / ORM querying capabilities. + + .. seealso:: + + :ref:`deferred_options` - refers to options specific to the usage + of ORM queries + + :ref:`relationship_loader_options` - refers to options specific + to the usage of ORM queries + + """ + self._with_options += tuple( + coercions.expect(roles.ExecutableOptionRole, opt) + for opt in options + ) + + @_generative + def _set_compile_options(self, compile_options): + """Assign the compile options to a new value. + + :param compile_options: appropriate CacheableOptions structure + + """ + + self._compile_options = compile_options + + @_generative + def _update_compile_options(self, options): + """update the _compile_options with new keys.""" + + self._compile_options += options + + @_generative + def _add_context_option(self, callable_, cache_args): + """Add a context option to this statement. + + These are callable functions that will + be given the CompileState object upon compilation. + + A second argument cache_args is required, which will be combined with + the ``__code__`` identity of the function itself in order to produce a + cache key. + + """ + self._with_context_options += ((callable_, cache_args),) + + @_generative + def execution_options(self, **kw): + """Set non-SQL options for the statement which take effect during + execution. + + Execution options can be set on a per-statement or + per :class:`_engine.Connection` basis. Additionally, the + :class:`_engine.Engine` and ORM :class:`~.orm.query.Query` + objects provide + access to execution options which they in turn configure upon + connections. + + The :meth:`execution_options` method is generative. A new + instance of this statement is returned that contains the options:: + + statement = select(table.c.x, table.c.y) + statement = statement.execution_options(autocommit=True) + + Note that only a subset of possible execution options can be applied + to a statement - these include "autocommit" and "stream_results", + but not "isolation_level" or "compiled_cache". + See :meth:`_engine.Connection.execution_options` for a full list of + possible options. + + .. seealso:: + + :meth:`_engine.Connection.execution_options` + + :meth:`_query.Query.execution_options` + + :meth:`.Executable.get_execution_options` + + """ + if "isolation_level" in kw: + raise exc.ArgumentError( + "'isolation_level' execution option may only be specified " + "on Connection.execution_options(), or " + "per-engine using the isolation_level " + "argument to create_engine()." + ) + if "compiled_cache" in kw: + raise exc.ArgumentError( + "'compiled_cache' execution option may only be specified " + "on Connection.execution_options(), not per statement." + ) + self._execution_options = self._execution_options.union(kw) + + def get_execution_options(self): + """Get the non-SQL options which will take effect during execution. + + .. versionadded:: 1.3 + + .. seealso:: + + :meth:`.Executable.execution_options` + """ + return self._execution_options + + @util.deprecated_20( + ":meth:`.Executable.execute`", + alternative="All statement execution in SQLAlchemy 2.0 is performed " + "by the :meth:`_engine.Connection.execute` method of " + ":class:`_engine.Connection`, " + "or in the ORM by the :meth:`.Session.execute` method of " + ":class:`.Session`.", + ) + def execute(self, *multiparams, **params): + """Compile and execute this :class:`.Executable`.""" + e = self.bind + if e is None: + label = ( + getattr(self, "description", None) or self.__class__.__name__ + ) + msg = ( + "This %s is not directly bound to a Connection or Engine. " + "Use the .execute() method of a Connection or Engine " + "to execute this construct." % label + ) + raise exc.UnboundExecutionError(msg) + return e._execute_clauseelement( + self, multiparams, params, util.immutabledict() + ) + + @util.deprecated_20( + ":meth:`.Executable.scalar`", + alternative="Scalar execution in SQLAlchemy 2.0 is performed " + "by the :meth:`_engine.Connection.scalar` method of " + ":class:`_engine.Connection`, " + "or in the ORM by the :meth:`.Session.scalar` method of " + ":class:`.Session`.", + ) + def scalar(self, *multiparams, **params): + """Compile and execute this :class:`.Executable`, returning the + result's scalar representation. + + """ + return self.execute(*multiparams, **params).scalar() + + @property + @util.deprecated_20( + ":attr:`.Executable.bind`", + alternative="Bound metadata is being removed as of SQLAlchemy 2.0.", + enable_warnings=False, + ) + def bind(self): + """Returns the :class:`_engine.Engine` or :class:`_engine.Connection` + to + which this :class:`.Executable` is bound, or None if none found. + + This is a traversal which checks locally, then + checks among the "from" clauses of associated objects + until a bound engine or connection is found. + + """ + if self._bind is not None: + return self._bind + + for f in _from_objects(self): + if f is self: + continue + engine = f.bind + if engine is not None: + return engine + else: + return None + + +class prefix_anon_map(dict): + """A map that creates new keys for missing key access. + + Considers keys of the form " " to produce + new symbols "_", where "index" is an incrementing integer + corresponding to . + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + def __missing__(self, key): + (ident, derived) = key.split(" ", 1) + anonymous_counter = self.get(derived, 1) + self[derived] = anonymous_counter + 1 + value = derived + "_" + str(anonymous_counter) + self[key] = value + return value + + +class SchemaEventTarget(object): + """Base class for elements that are the targets of :class:`.DDLEvents` + events. + + This includes :class:`.SchemaItem` as well as :class:`.SchemaType`. + + """ + + def _set_parent(self, parent, **kw): + """Associate with this SchemaEvent's parent object.""" + + def _set_parent_with_dispatch(self, parent, **kw): + self.dispatch.before_parent_attach(self, parent) + self._set_parent(parent, **kw) + self.dispatch.after_parent_attach(self, parent) + + +class SchemaVisitor(ClauseVisitor): + """Define the visiting for ``SchemaItem`` objects.""" + + __traverse_options__ = {"schema_visitor": True} + + +class ColumnCollection(object): + """Collection of :class:`_expression.ColumnElement` instances, + typically for + :class:`_sql.FromClause` objects. + + The :class:`_sql.ColumnCollection` object is most commonly available + as the :attr:`_schema.Table.c` or :attr:`_schema.Table.columns` collection + on the :class:`_schema.Table` object, introduced at + :ref:`metadata_tables_and_columns`. + + The :class:`_expression.ColumnCollection` has both mapping- and sequence- + like behaviors. A :class:`_expression.ColumnCollection` usually stores + :class:`_schema.Column` objects, which are then accessible both via mapping + style access as well as attribute access style. + + To access :class:`_schema.Column` objects using ordinary attribute-style + access, specify the name like any other object attribute, such as below + a column named ``employee_name`` is accessed:: + + >>> employee_table.c.employee_name + + To access columns that have names with special characters or spaces, + index-style access is used, such as below which illustrates a column named + ``employee ' payment`` is accessed:: + + >>> employee_table.c["employee ' payment"] + + As the :class:`_sql.ColumnCollection` object provides a Python dictionary + interface, common dictionary method names like + :meth:`_sql.ColumnCollection.keys`, :meth:`_sql.ColumnCollection.values`, + and :meth:`_sql.ColumnCollection.items` are available, which means that + database columns that are keyed under these names also need to use indexed + access:: + + >>> employee_table.c["values"] + + + The name for which a :class:`_schema.Column` would be present is normally + that of the :paramref:`_schema.Column.key` parameter. In some contexts, + such as a :class:`_sql.Select` object that uses a label style set + using the :meth:`_sql.Select.set_label_style` method, a column of a certain + key may instead be represented under a particular label name such + as ``tablename_columnname``:: + + >>> from sqlalchemy import select, column, table + >>> from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL + >>> t = table("t", column("c")) + >>> stmt = select(t).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + >>> subq = stmt.subquery() + >>> subq.c.t_c + + + :class:`.ColumnCollection` also indexes the columns in order and allows + them to be accessible by their integer position:: + + >>> cc[0] + Column('x', Integer(), table=None) + >>> cc[1] + Column('y', Integer(), table=None) + + .. versionadded:: 1.4 :class:`_expression.ColumnCollection` + allows integer-based + index access to the collection. + + Iterating the collection yields the column expressions in order:: + + >>> list(cc) + [Column('x', Integer(), table=None), + Column('y', Integer(), table=None)] + + The base :class:`_expression.ColumnCollection` object can store + duplicates, which can + mean either two columns with the same key, in which case the column + returned by key access is **arbitrary**:: + + >>> x1, x2 = Column('x', Integer), Column('x', Integer) + >>> cc = ColumnCollection(columns=[(x1.name, x1), (x2.name, x2)]) + >>> list(cc) + [Column('x', Integer(), table=None), + Column('x', Integer(), table=None)] + >>> cc['x'] is x1 + False + >>> cc['x'] is x2 + True + + Or it can also mean the same column multiple times. These cases are + supported as :class:`_expression.ColumnCollection` + is used to represent the columns in + a SELECT statement which may include duplicates. + + A special subclass :class:`.DedupeColumnCollection` exists which instead + maintains SQLAlchemy's older behavior of not allowing duplicates; this + collection is used for schema level objects like :class:`_schema.Table` + and + :class:`.PrimaryKeyConstraint` where this deduping is helpful. The + :class:`.DedupeColumnCollection` class also has additional mutation methods + as the schema constructs have more use cases that require removal and + replacement of columns. + + .. versionchanged:: 1.4 :class:`_expression.ColumnCollection` + now stores duplicate + column keys as well as the same column in multiple positions. The + :class:`.DedupeColumnCollection` class is added to maintain the + former behavior in those cases where deduplication as well as + additional replace/remove operations are needed. + + + """ + + __slots__ = "_collection", "_index", "_colset" + + def __init__(self, columns=None): + object.__setattr__(self, "_colset", set()) + object.__setattr__(self, "_index", {}) + object.__setattr__(self, "_collection", []) + if columns: + self._initial_populate(columns) + + def _initial_populate(self, iter_): + self._populate_separate_keys(iter_) + + @property + def _all_columns(self): + return [col for (k, col) in self._collection] + + def keys(self): + """Return a sequence of string key names for all columns in this + collection.""" + return [k for (k, col) in self._collection] + + def values(self): + """Return a sequence of :class:`_sql.ColumnClause` or + :class:`_schema.Column` objects for all columns in this + collection.""" + return [col for (k, col) in self._collection] + + def items(self): + """Return a sequence of (key, column) tuples for all columns in this + collection each consisting of a string key name and a + :class:`_sql.ColumnClause` or + :class:`_schema.Column` object. + """ + + return list(self._collection) + + def __bool__(self): + return bool(self._collection) + + def __len__(self): + return len(self._collection) + + def __iter__(self): + # turn to a list first to maintain over a course of changes + return iter([col for k, col in self._collection]) + + def __getitem__(self, key): + try: + return self._index[key] + except KeyError as err: + if isinstance(key, util.int_types): + util.raise_(IndexError(key), replace_context=err) + else: + raise + + def __getattr__(self, key): + try: + return self._index[key] + except KeyError as err: + util.raise_(AttributeError(key), replace_context=err) + + def __contains__(self, key): + if key not in self._index: + if not isinstance(key, util.string_types): + raise exc.ArgumentError( + "__contains__ requires a string argument" + ) + return False + else: + return True + + def compare(self, other): + """Compare this :class:`_expression.ColumnCollection` to another + based on the names of the keys""" + + for l, r in util.zip_longest(self, other): + if l is not r: + return False + else: + return True + + def __eq__(self, other): + return self.compare(other) + + def get(self, key, default=None): + """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object + based on a string key name from this + :class:`_expression.ColumnCollection`.""" + + if key in self._index: + return self._index[key] + else: + return default + + def __str__(self): + return "%s(%s)" % ( + self.__class__.__name__, + ", ".join(str(c) for c in self), + ) + + def __setitem__(self, key, value): + raise NotImplementedError() + + def __delitem__(self, key): + raise NotImplementedError() + + def __setattr__(self, key, obj): + raise NotImplementedError() + + def clear(self): + """Dictionary clear() is not implemented for + :class:`_sql.ColumnCollection`.""" + raise NotImplementedError() + + def remove(self, column): + """Dictionary remove() is not implemented for + :class:`_sql.ColumnCollection`.""" + raise NotImplementedError() + + def update(self, iter_): + """Dictionary update() is not implemented for + :class:`_sql.ColumnCollection`.""" + raise NotImplementedError() + + __hash__ = None + + def _populate_separate_keys(self, iter_): + """populate from an iterator of (key, column)""" + cols = list(iter_) + self._collection[:] = cols + self._colset.update(c for k, c in self._collection) + self._index.update( + (idx, c) for idx, (k, c) in enumerate(self._collection) + ) + self._index.update({k: col for k, col in reversed(self._collection)}) + + def add(self, column, key=None): + """Add a column to this :class:`_sql.ColumnCollection`. + + .. note:: + + This method is **not normally used by user-facing code**, as the + :class:`_sql.ColumnCollection` is usually part of an existing + object such as a :class:`_schema.Table`. To add a + :class:`_schema.Column` to an existing :class:`_schema.Table` + object, use the :meth:`_schema.Table.append_column` method. + + """ + if key is None: + key = column.key + + l = len(self._collection) + self._collection.append((key, column)) + self._colset.add(column) + self._index[l] = column + if key not in self._index: + self._index[key] = column + + def __getstate__(self): + return {"_collection": self._collection, "_index": self._index} + + def __setstate__(self, state): + object.__setattr__(self, "_index", state["_index"]) + object.__setattr__(self, "_collection", state["_collection"]) + object.__setattr__( + self, "_colset", {col for k, col in self._collection} + ) + + def contains_column(self, col): + """Checks if a column object exists in this collection""" + if col not in self._colset: + if isinstance(col, util.string_types): + raise exc.ArgumentError( + "contains_column cannot be used with string arguments. " + "Use ``col_name in table.c`` instead." + ) + return False + else: + return True + + def as_immutable(self): + """Return an "immutable" form of this + :class:`_sql.ColumnCollection`.""" + + return ImmutableColumnCollection(self) + + def corresponding_column(self, column, require_embedded=False): + """Given a :class:`_expression.ColumnElement`, return the exported + :class:`_expression.ColumnElement` object from this + :class:`_expression.ColumnCollection` + which corresponds to that original :class:`_expression.ColumnElement` + via a common + ancestor column. + + :param column: the target :class:`_expression.ColumnElement` + to be matched. + + :param require_embedded: only return corresponding columns for + the given :class:`_expression.ColumnElement`, if the given + :class:`_expression.ColumnElement` + is actually present within a sub-element + of this :class:`_expression.Selectable`. + Normally the column will match if + it merely shares a common ancestor with one of the exported + columns of this :class:`_expression.Selectable`. + + .. seealso:: + + :meth:`_expression.Selectable.corresponding_column` + - invokes this method + against the collection returned by + :attr:`_expression.Selectable.exported_columns`. + + .. versionchanged:: 1.4 the implementation for ``corresponding_column`` + was moved onto the :class:`_expression.ColumnCollection` itself. + + """ + + def embedded(expanded_proxy_set, target_set): + for t in target_set.difference(expanded_proxy_set): + if not set(_expand_cloned([t])).intersection( + expanded_proxy_set + ): + return False + return True + + # don't dig around if the column is locally present + if column in self._colset: + return column + col, intersect = None, None + target_set = column.proxy_set + cols = [c for (k, c) in self._collection] + for c in cols: + expanded_proxy_set = set(_expand_cloned(c.proxy_set)) + i = target_set.intersection(expanded_proxy_set) + if i and ( + not require_embedded + or embedded(expanded_proxy_set, target_set) + ): + if col is None: + + # no corresponding column yet, pick this one. + + col, intersect = c, i + elif len(i) > len(intersect): + + # 'c' has a larger field of correspondence than + # 'col'. i.e. selectable.c.a1_x->a1.c.x->table.c.x + # matches a1.c.x->table.c.x better than + # selectable.c.x->table.c.x does. + + col, intersect = c, i + elif i == intersect: + # they have the same field of correspondence. see + # which proxy_set has fewer columns in it, which + # indicates a closer relationship with the root + # column. Also take into account the "weight" + # attribute which CompoundSelect() uses to give + # higher precedence to columns based on vertical + # position in the compound statement, and discard + # columns that have no reference to the target + # column (also occurs with CompoundSelect) + + col_distance = util.reduce( + operator.add, + [ + sc._annotations.get("weight", 1) + for sc in col._uncached_proxy_set() + if sc.shares_lineage(column) + ], + ) + c_distance = util.reduce( + operator.add, + [ + sc._annotations.get("weight", 1) + for sc in c._uncached_proxy_set() + if sc.shares_lineage(column) + ], + ) + if c_distance < col_distance: + col, intersect = c, i + return col + + +class DedupeColumnCollection(ColumnCollection): + """A :class:`_expression.ColumnCollection` + that maintains deduplicating behavior. + + This is useful by schema level objects such as :class:`_schema.Table` and + :class:`.PrimaryKeyConstraint`. The collection includes more + sophisticated mutator methods as well to suit schema objects which + require mutable column collections. + + .. versionadded:: 1.4 + + """ + + def add(self, column, key=None): + + if key is not None and column.key != key: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + key = column.key + + if key is None: + raise exc.ArgumentError( + "Can't add unnamed column to column collection" + ) + + if key in self._index: + + existing = self._index[key] + + if existing is column: + return + + self.replace(column) + + # pop out memoized proxy_set as this + # operation may very well be occurring + # in a _make_proxy operation + util.memoized_property.reset(column, "proxy_set") + else: + l = len(self._collection) + self._collection.append((key, column)) + self._colset.add(column) + self._index[l] = column + self._index[key] = column + + def _populate_separate_keys(self, iter_): + """populate from an iterator of (key, column)""" + cols = list(iter_) + + replace_col = [] + for k, col in cols: + if col.key != k: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + if col.name in self._index and col.key != col.name: + replace_col.append(col) + elif col.key in self._index: + replace_col.append(col) + else: + self._index[k] = col + self._collection.append((k, col)) + self._colset.update(c for (k, c) in self._collection) + self._index.update( + (idx, c) for idx, (k, c) in enumerate(self._collection) + ) + for col in replace_col: + self.replace(col) + + def extend(self, iter_): + self._populate_separate_keys((col.key, col) for col in iter_) + + def remove(self, column): + if column not in self._colset: + raise ValueError( + "Can't remove column %r; column is not in this collection" + % column + ) + del self._index[column.key] + self._colset.remove(column) + self._collection[:] = [ + (k, c) for (k, c) in self._collection if c is not column + ] + self._index.update( + {idx: col for idx, (k, col) in enumerate(self._collection)} + ) + # delete higher index + del self._index[len(self._collection)] + + def replace(self, column): + """add the given column to this collection, removing unaliased + versions of this column as well as existing columns with the + same key. + + e.g.:: + + t = Table('sometable', metadata, Column('col1', Integer)) + t.columns.replace(Column('col1', Integer, key='columnone')) + + will remove the original 'col1' from the collection, and add + the new column under the name 'columnname'. + + Used by schema.Column to override columns during table reflection. + + """ + + remove_col = set() + # remove up to two columns based on matches of name as well as key + if column.name in self._index and column.key != column.name: + other = self._index[column.name] + if other.name == other.key: + remove_col.add(other) + + if column.key in self._index: + remove_col.add(self._index[column.key]) + + new_cols = [] + replaced = False + for k, col in self._collection: + if col in remove_col: + if not replaced: + replaced = True + new_cols.append((column.key, column)) + else: + new_cols.append((k, col)) + + if remove_col: + self._colset.difference_update(remove_col) + + if not replaced: + new_cols.append((column.key, column)) + + self._colset.add(column) + self._collection[:] = new_cols + + self._index.clear() + self._index.update( + {idx: col for idx, (k, col) in enumerate(self._collection)} + ) + self._index.update(self._collection) + + +class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection): + __slots__ = ("_parent",) + + def __init__(self, collection): + object.__setattr__(self, "_parent", collection) + object.__setattr__(self, "_colset", collection._colset) + object.__setattr__(self, "_index", collection._index) + object.__setattr__(self, "_collection", collection._collection) + + def __getstate__(self): + return {"_parent": self._parent} + + def __setstate__(self, state): + parent = state["_parent"] + self.__init__(parent) + + add = extend = remove = util.ImmutableContainer._immutable + + +class ColumnSet(util.ordered_column_set): + def contains_column(self, col): + return col in self + + def extend(self, cols): + for col in cols: + self.add(col) + + def __add__(self, other): + return list(self) + list(other) + + def __eq__(self, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c == local) + return elements.and_(*l) + + def __hash__(self): + return hash(tuple(x for x in self)) + + +def _bind_or_error(schemaitem, msg=None): + + util.warn_deprecated_20( + "The ``bind`` argument for schema methods that invoke SQL " + "against an engine or connection will be required in SQLAlchemy 2.0." + ) + bind = schemaitem.bind + if not bind: + name = schemaitem.__class__.__name__ + label = getattr( + schemaitem, "fullname", getattr(schemaitem, "name", None) + ) + if label: + item = "%s object %r" % (name, label) + else: + item = "%s object" % name + if msg is None: + msg = ( + "%s is not bound to an Engine or Connection. " + "Execution can not proceed without a database to execute " + "against." % item + ) + raise exc.UnboundExecutionError(msg) + return bind + + +def _entity_namespace(entity): + """Return the nearest .entity_namespace for the given entity. + + If not immediately available, does an iterate to find a sub-element + that has one, if any. + + """ + try: + return entity.entity_namespace + except AttributeError: + for elem in visitors.iterate(entity): + if hasattr(elem, "entity_namespace"): + return elem.entity_namespace + else: + raise + + +def _entity_namespace_key(entity, key, default=NO_ARG): + """Return an entry from an entity_namespace. + + + Raises :class:`_exc.InvalidRequestError` rather than attribute error + on not found. + + """ + + try: + ns = _entity_namespace(entity) + if default is not NO_ARG: + return getattr(ns, key, default) + else: + return getattr(ns, key) + except AttributeError as err: + util.raise_( + exc.InvalidRequestError( + 'Entity namespace for "%s" has no property "%s"' + % (entity, key) + ), + replace_context=err, + ) diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py new file mode 100644 index 0000000..8cc73cb --- /dev/null +++ b/lib/sqlalchemy/sql/coercions.py @@ -0,0 +1,1096 @@ +# sql/coercions.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import numbers +import re + +from . import operators +from . import roles +from . import visitors +from .base import ExecutableOption +from .base import Options +from .traversals import HasCacheKey +from .visitors import Visitable +from .. import exc +from .. import inspection +from .. import util +from ..util import collections_abc + + +elements = None +lambdas = None +schema = None +selectable = None +sqltypes = None +traversals = None + + +def _is_literal(element): + """Return whether or not the element is a "literal" in the context + of a SQL expression construct. + + """ + + return ( + not isinstance( + element, + (Visitable, schema.SchemaEventTarget), + ) + and not hasattr(element, "__clause_element__") + ) + + +def _deep_is_literal(element): + """Return whether or not the element is a "literal" in the context + of a SQL expression construct. + + does a deeper more esoteric check than _is_literal. is used + for lambda elements that have to distinguish values that would + be bound vs. not without any context. + + """ + + if isinstance(element, collections_abc.Sequence) and not isinstance( + element, str + ): + for elem in element: + if not _deep_is_literal(elem): + return False + else: + return True + + return ( + not isinstance( + element, + ( + Visitable, + schema.SchemaEventTarget, + HasCacheKey, + Options, + util.langhelpers._symbol, + ), + ) + and not hasattr(element, "__clause_element__") + and ( + not isinstance(element, type) + or not issubclass(element, HasCacheKey) + ) + ) + + +def _document_text_coercion(paramname, meth_rst, param_rst): + return util.add_parameter_text( + paramname, + ( + ".. warning:: " + "The %s argument to %s can be passed as a Python string argument, " + "which will be treated " + "as **trusted SQL text** and rendered as given. **DO NOT PASS " + "UNTRUSTED INPUT TO THIS PARAMETER**." + ) + % (param_rst, meth_rst), + ) + + +def _expression_collection_was_a_list(attrname, fnname, args): + if args and isinstance(args[0], (list, set, dict)) and len(args) == 1: + if isinstance(args[0], list): + util.warn_deprecated_20( + 'The "%s" argument to %s(), when referring to a sequence ' + "of items, is now passed as a series of positional " + "elements, rather than as a list. " % (attrname, fnname) + ) + return args[0] + else: + return args + + +def expect( + role, + element, + apply_propagate_attrs=None, + argname=None, + post_inspect=False, + **kw +): + if ( + role.allows_lambda + # note callable() will not invoke a __getattr__() method, whereas + # hasattr(obj, "__call__") will. by keeping the callable() check here + # we prevent most needless calls to hasattr() and therefore + # __getattr__(), which is present on ColumnElement. + and callable(element) + and hasattr(element, "__code__") + ): + return lambdas.LambdaElement( + element, + role, + lambdas.LambdaOptions(**kw), + apply_propagate_attrs=apply_propagate_attrs, + ) + + # major case is that we are given a ClauseElement already, skip more + # elaborate logic up front if possible + impl = _impl_lookup[role] + + original_element = element + + if not isinstance( + element, + (elements.ClauseElement, schema.SchemaItem, schema.FetchedValue), + ): + resolved = None + + if impl._resolve_literal_only: + resolved = impl._literal_coercion(element, **kw) + else: + + original_element = element + + is_clause_element = False + + # this is a special performance optimization for ORM + # joins used by JoinTargetImpl that we don't go through the + # work of creating __clause_element__() when we only need the + # original QueryableAttribute, as the former will do clause + # adaption and all that which is just thrown away here. + if ( + impl._skip_clauseelement_for_target_match + and isinstance(element, role) + and hasattr(element, "__clause_element__") + ): + is_clause_element = True + else: + while hasattr(element, "__clause_element__"): + is_clause_element = True + + if not getattr(element, "is_clause_element", False): + element = element.__clause_element__() + else: + break + + if not is_clause_element: + if impl._use_inspection: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + if post_inspect: + insp._post_inspect + try: + resolved = insp.__clause_element__() + except AttributeError: + impl._raise_for_expected(original_element, argname) + + if resolved is None: + resolved = impl._literal_coercion( + element, argname=argname, **kw + ) + else: + resolved = element + else: + resolved = element + if ( + apply_propagate_attrs is not None + and not apply_propagate_attrs._propagate_attrs + and resolved._propagate_attrs + ): + apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs + + if impl._role_class in resolved.__class__.__mro__: + if impl._post_coercion: + resolved = impl._post_coercion( + resolved, + argname=argname, + original_element=original_element, + **kw + ) + return resolved + else: + return impl._implicit_coercions( + original_element, resolved, argname=argname, **kw + ) + + +def expect_as_key(role, element, **kw): + kw["as_key"] = True + return expect(role, element, **kw) + + +def expect_col_expression_collection(role, expressions): + for expr in expressions: + strname = None + column = None + + resolved = expect(role, expr) + if isinstance(resolved, util.string_types): + strname = resolved = expr + else: + cols = [] + visitors.traverse(resolved, {}, {"column": cols.append}) + if cols: + column = cols[0] + add_element = column if column is not None else strname + yield resolved, column, strname, add_element + + +class RoleImpl(object): + __slots__ = ("_role_class", "name", "_use_inspection") + + def _literal_coercion(self, element, **kw): + raise NotImplementedError() + + _post_coercion = None + _resolve_literal_only = False + _skip_clauseelement_for_target_match = False + + def __init__(self, role_class): + self._role_class = role_class + self.name = role_class._role_name + self._use_inspection = issubclass(role_class, roles.UsesInspection) + + def _implicit_coercions(self, element, resolved, argname=None, **kw): + self._raise_for_expected(element, argname, resolved) + + def _raise_for_expected( + self, + element, + argname=None, + resolved=None, + advice=None, + code=None, + err=None, + ): + if resolved is not None and resolved is not element: + got = "%r object resolved from %r object" % (resolved, element) + else: + got = repr(element) + + if argname: + msg = "%s expected for argument %r; got %s." % ( + self.name, + argname, + got, + ) + else: + msg = "%s expected, got %s." % (self.name, got) + + if advice: + msg += " " + advice + + util.raise_(exc.ArgumentError(msg, code=code), replace_context=err) + + +class _Deannotate(object): + __slots__ = () + + def _post_coercion(self, resolved, **kw): + from .util import _deep_deannotate + + return _deep_deannotate(resolved) + + +class _StringOnly(object): + __slots__ = () + + _resolve_literal_only = True + + +class _ReturnsStringKey(object): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(original_element, util.string_types): + return original_element + else: + self._raise_for_expected(original_element, argname, resolved) + + def _literal_coercion(self, element, **kw): + return element + + +class _ColumnCoercions(object): + __slots__ = () + + def _warn_for_scalar_subquery_coercion(self): + util.warn( + "implicitly coercing SELECT object to scalar subquery; " + "please use the .scalar_subquery() method to produce a scalar " + "subquery.", + ) + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if not getattr(resolved, "is_clause_element", False): + self._raise_for_expected(original_element, argname, resolved) + elif resolved._is_select_statement: + self._warn_for_scalar_subquery_coercion() + return resolved.scalar_subquery() + elif resolved._is_from_clause and isinstance( + resolved, selectable.Subquery + ): + self._warn_for_scalar_subquery_coercion() + return resolved.element.scalar_subquery() + elif self._role_class.allows_lambda and resolved._is_lambda_element: + return resolved + else: + self._raise_for_expected(original_element, argname, resolved) + + +def _no_text_coercion( + element, argname=None, exc_cls=exc.ArgumentError, extra=None, err=None +): + util.raise_( + exc_cls( + "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " + "explicitly declared as text(%(expr)r)" + % { + "expr": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "extra": "%s " % extra if extra else "", + } + ), + replace_context=err, + ) + + +class _NoTextCoercion(object): + __slots__ = () + + def _literal_coercion(self, element, argname=None, **kw): + if isinstance(element, util.string_types) and issubclass( + elements.TextClause, self._role_class + ): + _no_text_coercion(element, argname) + else: + self._raise_for_expected(element, argname) + + +class _CoerceLiterals(object): + __slots__ = () + _coerce_consts = False + _coerce_star = False + _coerce_numerics = False + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + def _literal_coercion(self, element, argname=None, **kw): + if isinstance(element, util.string_types): + if self._coerce_star and element == "*": + return elements.ColumnClause("*", is_literal=True) + else: + return self._text_coercion(element, argname, **kw) + + if self._coerce_consts: + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + + if self._coerce_numerics and isinstance(element, (numbers.Number)): + return elements.ColumnClause(str(element), is_literal=True) + + self._raise_for_expected(element, argname) + + +class LiteralValueImpl(RoleImpl): + _resolve_literal_only = True + + def _implicit_coercions( + self, element, resolved, argname, type_=None, **kw + ): + if not _is_literal(resolved): + self._raise_for_expected( + element, resolved=resolved, argname=argname, **kw + ) + + return elements.BindParameter(None, element, type_=type_, unique=True) + + def _literal_coercion(self, element, argname=None, type_=None, **kw): + return element + + +class _SelectIsNotFrom(object): + __slots__ = () + + def _raise_for_expected(self, element, argname=None, resolved=None, **kw): + if isinstance(element, roles.SelectStatementRole) or isinstance( + resolved, roles.SelectStatementRole + ): + advice = ( + "To create a " + "FROM clause from a %s object, use the .subquery() method." + % (resolved.__class__ if resolved is not None else element,) + ) + code = "89ve" + else: + advice = code = None + + return super(_SelectIsNotFrom, self)._raise_for_expected( + element, + argname=argname, + resolved=resolved, + advice=advice, + code=code, + **kw + ) + + +class HasCacheKeyImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(original_element, traversals.HasCacheKey): + return original_element + else: + self._raise_for_expected(original_element, argname, resolved) + + def _literal_coercion(self, element, **kw): + return element + + +class ExecutableOptionImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(original_element, ExecutableOption): + return original_element + else: + self._raise_for_expected(original_element, argname, resolved) + + def _literal_coercion(self, element, **kw): + return element + + +class ExpressionElementImpl(_ColumnCoercions, RoleImpl): + __slots__ = () + + def _literal_coercion( + self, element, name=None, type_=None, argname=None, is_crud=False, **kw + ): + if ( + element is None + and not is_crud + and (type_ is None or not type_.should_evaluate_none) + ): + # TODO: there's no test coverage now for the + # "should_evaluate_none" part of this, as outside of "crud" this + # codepath is not normally used except in some special cases + return elements.Null() + else: + try: + return elements.BindParameter( + name, element, type_, unique=True, _is_crud=is_crud + ) + except exc.ArgumentError as err: + self._raise_for_expected(element, err=err) + + def _raise_for_expected(self, element, argname=None, resolved=None, **kw): + if isinstance(element, roles.AnonymizedFromClauseRole): + advice = ( + "To create a " + "column expression from a FROM clause row " + "as a whole, use the .table_valued() method." + ) + else: + advice = None + + return super(ExpressionElementImpl, self)._raise_for_expected( + element, argname=argname, resolved=resolved, advice=advice, **kw + ) + + +class BinaryElementImpl(ExpressionElementImpl, RoleImpl): + + __slots__ = () + + def _literal_coercion( + self, element, expr, operator, bindparam_type=None, argname=None, **kw + ): + try: + return expr._bind_param(operator, element, type_=bindparam_type) + except exc.ArgumentError as err: + self._raise_for_expected(element, err=err) + + def _post_coercion(self, resolved, expr, bindparam_type=None, **kw): + if resolved.type._isnull and not expr.type._isnull: + resolved = resolved._with_binary_element_type( + bindparam_type if bindparam_type is not None else expr.type + ) + return resolved + + +class InElementImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.element._is_select_statement + ): + self._warn_for_implicit_coercion(resolved) + return self._post_coercion(resolved.element, **kw) + else: + self._warn_for_implicit_coercion(resolved) + return self._post_coercion(resolved.select(), **kw) + else: + self._raise_for_expected(original_element, argname, resolved) + + def _warn_for_implicit_coercion(self, elem): + util.warn( + "Coercing %s object into a select() for use in IN(); " + "please pass a select() construct explicitly" + % (elem.__class__.__name__) + ) + + def _literal_coercion(self, element, expr, operator, **kw): + if isinstance(element, collections_abc.Iterable) and not isinstance( + element, util.string_types + ): + non_literal_expressions = {} + element = list(element) + for o in element: + if not _is_literal(o): + if not isinstance(o, operators.ColumnOperators): + self._raise_for_expected(element, **kw) + else: + non_literal_expressions[o] = o + elif o is None: + non_literal_expressions[o] = elements.Null() + + if non_literal_expressions: + return elements.ClauseList( + *[ + non_literal_expressions[o] + if o in non_literal_expressions + else expr._bind_param(operator, o) + for o in element + ] + ) + else: + return expr._bind_param(operator, element, expanding=True) + + else: + self._raise_for_expected(element, **kw) + + def _post_coercion(self, element, expr, operator, **kw): + if element._is_select_statement: + # for IN, we are doing scalar_subquery() coercion without + # a warning + return element.scalar_subquery() + elif isinstance(element, elements.ClauseList): + assert not len(element.clauses) == 0 + return element.self_group(against=operator) + + elif isinstance(element, elements.BindParameter): + element = element._clone(maintain_key=True) + element.expanding = True + element.expand_op = operator + + return element + else: + return element + + +class OnClauseImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl): + __slots__ = () + + _coerce_consts = True + + def _implicit_coercions( + self, original_element, resolved, argname=None, legacy=False, **kw + ): + if legacy and isinstance(resolved, str): + return resolved + else: + return super(OnClauseImpl, self)._implicit_coercions( + original_element, + resolved, + argname=argname, + legacy=legacy, + **kw + ) + + def _text_coercion(self, element, argname=None, legacy=False): + if legacy and isinstance(element, str): + util.warn_deprecated_20( + "Using strings to indicate relationship names in " + "Query.join() is deprecated and will be removed in " + "SQLAlchemy 2.0. Please use the class-bound attribute " + "directly." + ) + return element + + return super(OnClauseImpl, self)._text_coercion(element, argname) + + def _post_coercion(self, resolved, original_element=None, **kw): + # this is a hack right now as we want to use coercion on an + # ORM InstrumentedAttribute, but we want to return the object + # itself if it is one, not its clause element. + # ORM context _join and _legacy_join() would need to be improved + # to look for annotations in a clause element form. + if isinstance(original_element, roles.JoinTargetRole): + return original_element + return resolved + + +class WhereHavingImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl): + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + +class StatementOptionImpl(_CoerceLiterals, RoleImpl): + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class ColumnArgumentImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + +class ColumnArgumentOrKeyImpl(_ReturnsStringKey, RoleImpl): + __slots__ = () + + +class StrAsPlainColumnImpl(_CoerceLiterals, RoleImpl): + __slots__ = () + + def _text_coercion(self, element, argname=None): + return elements.ColumnClause(element) + + +class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole): + + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements._textual_label_reference(element) + + +class OrderByImpl(ByOfImpl, RoleImpl): + __slots__ = () + + def _post_coercion(self, resolved, **kw): + if ( + isinstance(resolved, self._role_class) + and resolved._order_by_label_element is not None + ): + return elements._label_reference(resolved) + else: + return resolved + + +class GroupByImpl(ByOfImpl, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(resolved, roles.StrictFromClauseRole): + return elements.ClauseList(*resolved.c) + else: + return resolved + + +class DMLColumnImpl(_ReturnsStringKey, RoleImpl): + __slots__ = () + + def _post_coercion(self, element, as_key=False, **kw): + if as_key: + return element.key + else: + return element + + +class ConstExprImpl(RoleImpl): + __slots__ = () + + def _literal_coercion(self, element, argname=None, **kw): + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + else: + self._raise_for_expected(element, argname) + + +class TruncatedLabelImpl(_StringOnly, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(original_element, util.string_types): + return resolved + else: + self._raise_for_expected(original_element, argname, resolved) + + def _literal_coercion(self, element, argname=None, **kw): + """coerce the given value to :class:`._truncated_label`. + + Existing :class:`._truncated_label` and + :class:`._anonymous_label` objects are passed + unchanged. + """ + + if isinstance(element, elements._truncated_label): + return element + else: + return elements._truncated_label(element) + + +class DDLExpressionImpl(_Deannotate, _CoerceLiterals, RoleImpl): + + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + # see #5754 for why we can't easily deprecate this coercion. + # essentially expressions like postgresql_where would have to be + # text() as they come back from reflection and we don't want to + # have text() elements wired into the inspection dictionaries. + return elements.TextClause(element) + + +class DDLConstraintColumnImpl(_Deannotate, _ReturnsStringKey, RoleImpl): + __slots__ = () + + +class DDLReferredColumnImpl(DDLConstraintColumnImpl): + __slots__ = () + + +class LimitOffsetImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions(self, element, resolved, argname=None, **kw): + if resolved is None: + return None + else: + self._raise_for_expected(element, argname, resolved) + + def _literal_coercion(self, element, name, type_, **kw): + if element is None: + return None + else: + value = util.asint(element) + return selectable._OffsetLimitParam( + name, value, type_=type_, unique=True + ) + + +class LabeledColumnExprImpl(ExpressionElementImpl): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if isinstance(resolved, roles.ExpressionElementRole): + return resolved.label(None) + else: + new = super(LabeledColumnExprImpl, self)._implicit_coercions( + original_element, resolved, argname=argname, **kw + ) + if isinstance(new, roles.ExpressionElementRole): + return new.label(None) + else: + self._raise_for_expected(original_element, argname, resolved) + + +class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): + __slots__ = () + + _coerce_consts = True + _coerce_numerics = True + _coerce_star = True + + _guess_straight_column = re.compile(r"^\w\S*$", re.I) + + def _text_coercion(self, element, argname=None): + element = str(element) + + guess_is_literal = not self._guess_straight_column.match(element) + raise exc.ArgumentError( + "Textual column expression %(column)r %(argname)sshould be " + "explicitly declared with text(%(column)r), " + "or use %(literal_column)s(%(column)r) " + "for more specificity" + % { + "column": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "literal_column": "literal_column" + if guess_is_literal + else "column", + } + ) + + +class ReturnsRowsImpl(RoleImpl): + __slots__ = () + + +class StatementImpl(_CoerceLiterals, RoleImpl): + __slots__ = () + + def _post_coercion(self, resolved, original_element, argname=None, **kw): + if resolved is not original_element and not isinstance( + original_element, util.string_types + ): + # use same method as Connection uses; this will later raise + # ObjectNotExecutableError + try: + original_element._execute_on_connection + except AttributeError: + util.warn_deprecated( + "Object %r should not be used directly in a SQL statement " + "context, such as passing to methods such as " + "session.execute(). This usage will be disallowed in a " + "future release. " + "Please use Core select() / update() / delete() etc. " + "with Session.execute() and other statement execution " + "methods." % original_element, + "1.4", + ) + + return resolved + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_lambda_element: + return resolved + else: + return super(StatementImpl, self)._implicit_coercions( + original_element, resolved, argname=argname, **kw + ) + + def _text_coercion(self, element, argname=None): + util.warn_deprecated_20( + "Using plain strings to indicate SQL statements without using " + "the text() construct is " + "deprecated and will be removed in version 2.0. Ensure plain " + "SQL statements are passed using the text() construct." + ) + return elements.TextClause(element) + + +class SelectStatementImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_text_clause: + return resolved.columns() + else: + self._raise_for_expected(original_element, argname, resolved) + + +class HasCTEImpl(ReturnsRowsImpl): + __slots__ = () + + +class IsCTEImpl(RoleImpl): + __slots__ = () + + +class JoinTargetImpl(RoleImpl): + __slots__ = () + + _skip_clauseelement_for_target_match = True + + def _literal_coercion(self, element, legacy=False, **kw): + if isinstance(element, str): + return element + + def _implicit_coercions( + self, original_element, resolved, argname=None, legacy=False, **kw + ): + if isinstance(original_element, roles.JoinTargetRole): + # note that this codepath no longer occurs as of + # #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match + # were set to False. + return original_element + elif legacy and isinstance(resolved, str): + util.warn_deprecated_20( + "Using strings to indicate relationship names in " + "Query.join() is deprecated and will be removed in " + "SQLAlchemy 2.0. Please use the class-bound attribute " + "directly." + ) + return resolved + elif legacy and isinstance(resolved, roles.WhereHavingRole): + return resolved + elif legacy and resolved._is_select_statement: + util.warn_deprecated( + "Implicit coercion of SELECT and textual SELECT " + "constructs into FROM clauses is deprecated; please call " + ".subquery() on any Core select or ORM Query object in " + "order to produce a subquery object.", + version="1.4", + ) + # TODO: doing _implicit_subquery here causes tests to fail, + # how was this working before? probably that ORM + # join logic treated it as a select and subquery would happen + # in _ORMJoin->Join + return resolved + else: + self._raise_for_expected(original_element, argname, resolved) + + +class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + original_element, + resolved, + argname=None, + explicit_subquery=False, + allow_select=True, + **kw + ): + if resolved._is_select_statement: + if explicit_subquery: + return resolved.subquery() + elif allow_select: + util.warn_deprecated( + "Implicit coercion of SELECT and textual SELECT " + "constructs into FROM clauses is deprecated; please call " + ".subquery() on any Core select or ORM Query object in " + "order to produce a subquery object.", + version="1.4", + ) + return resolved._implicit_subquery + elif resolved._is_text_clause: + return resolved + else: + self._raise_for_expected(original_element, argname, resolved) + + def _post_coercion(self, element, deannotate=False, **kw): + if deannotate: + return element._deannotate() + else: + return element + + +class StrictFromClauseImpl(FromClauseImpl): + __slots__ = () + + def _implicit_coercions( + self, + original_element, + resolved, + argname=None, + allow_select=False, + **kw + ): + if resolved._is_select_statement and allow_select: + util.warn_deprecated( + "Implicit coercion of SELECT and textual SELECT constructs " + "into FROM clauses is deprecated; please call .subquery() " + "on any Core select or ORM Query object in order to produce a " + "subquery object.", + version="1.4", + ) + return resolved._implicit_subquery + else: + self._raise_for_expected(original_element, argname, resolved) + + +class AnonymizedFromClauseImpl(StrictFromClauseImpl): + __slots__ = () + + def _post_coercion(self, element, flat=False, name=None, **kw): + assert name is None + + return element._anonymous_fromclause(flat=flat) + + +class DMLTableImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): + __slots__ = () + + def _post_coercion(self, element, **kw): + if "dml_table" in element._annotations: + return element._annotations["dml_table"] + else: + return element + + +class DMLSelectImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, original_element, resolved, argname=None, **kw + ): + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.element._is_select_statement + ): + return resolved.element + else: + return resolved.select() + else: + self._raise_for_expected(original_element, argname, resolved) + + +class CompoundElementImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + def _raise_for_expected(self, element, argname=None, resolved=None, **kw): + if isinstance(element, roles.FromClauseRole): + if element._is_subquery: + advice = ( + "Use the plain select() object without " + "calling .subquery() or .alias()." + ) + else: + advice = ( + "To SELECT from any FROM clause, use the .select() method." + ) + else: + advice = None + return super(CompoundElementImpl, self)._raise_for_expected( + element, argname=argname, resolved=resolved, advice=advice, **kw + ) + + +_impl_lookup = {} + + +for name in dir(roles): + cls = getattr(roles, name) + if name.endswith("Role"): + name = name.replace("Role", "Impl") + if name in globals(): + impl = globals()[name](cls) + _impl_lookup[cls] = impl diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py new file mode 100644 index 0000000..c9b6ba6 --- /dev/null +++ b/lib/sqlalchemy/sql/compiler.py @@ -0,0 +1,5525 @@ +# sql/compiler.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Base SQL and DDL compiler implementations. + +Classes provided include: + +:class:`.compiler.SQLCompiler` - renders SQL +strings + +:class:`.compiler.DDLCompiler` - renders DDL +(data definition language) strings + +:class:`.compiler.GenericTypeCompiler` - renders +type specification strings. + +To generate user-defined SQL strings, see +:doc:`/ext/compiler`. + +""" + +import collections +import contextlib +import itertools +import operator +import re + +from . import base +from . import coercions +from . import crud +from . import elements +from . import functions +from . import operators +from . import schema +from . import selectable +from . import sqltypes +from .base import NO_ARG +from .base import prefix_anon_map +from .elements import quoted_name +from .. import exc +from .. import util + +RESERVED_WORDS = set( + [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "authorization", + "between", + "binary", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "for", + "foreign", + "freeze", + "from", + "full", + "grant", + "group", + "having", + "ilike", + "in", + "initially", + "inner", + "intersect", + "into", + "is", + "isnull", + "join", + "leading", + "left", + "like", + "limit", + "localtime", + "localtimestamp", + "natural", + "new", + "not", + "notnull", + "null", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "outer", + "overlaps", + "placing", + "primary", + "references", + "right", + "select", + "session_user", + "set", + "similar", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "verbose", + "when", + "where", + ] +) + +LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I) +LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I) +ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"]) + +FK_ON_DELETE = re.compile( + r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I +) +FK_ON_UPDATE = re.compile( + r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I +) +FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I) +BIND_PARAMS = re.compile(r"(? ", + operators.ge: " >= ", + operators.eq: " = ", + operators.is_distinct_from: " IS DISTINCT FROM ", + operators.is_not_distinct_from: " IS NOT DISTINCT FROM ", + operators.concat_op: " || ", + operators.match_op: " MATCH ", + operators.not_match_op: " NOT MATCH ", + operators.in_op: " IN ", + operators.not_in_op: " NOT IN ", + operators.comma_op: ", ", + operators.from_: " FROM ", + operators.as_: " AS ", + operators.is_: " IS ", + operators.is_not: " IS NOT ", + operators.collate: " COLLATE ", + # unary + operators.exists: "EXISTS ", + operators.distinct_op: "DISTINCT ", + operators.inv: "NOT ", + operators.any_op: "ANY ", + operators.all_op: "ALL ", + # modifiers + operators.desc_op: " DESC", + operators.asc_op: " ASC", + operators.nulls_first_op: " NULLS FIRST", + operators.nulls_last_op: " NULLS LAST", +} + +FUNCTIONS = { + functions.coalesce: "coalesce", + functions.current_date: "CURRENT_DATE", + functions.current_time: "CURRENT_TIME", + functions.current_timestamp: "CURRENT_TIMESTAMP", + functions.current_user: "CURRENT_USER", + functions.localtime: "LOCALTIME", + functions.localtimestamp: "LOCALTIMESTAMP", + functions.random: "random", + functions.sysdate: "sysdate", + functions.session_user: "SESSION_USER", + functions.user: "USER", + functions.cube: "CUBE", + functions.rollup: "ROLLUP", + functions.grouping_sets: "GROUPING SETS", +} + +EXTRACT_MAP = { + "month": "month", + "day": "day", + "year": "year", + "second": "second", + "hour": "hour", + "doy": "doy", + "minute": "minute", + "quarter": "quarter", + "dow": "dow", + "week": "week", + "epoch": "epoch", + "milliseconds": "milliseconds", + "microseconds": "microseconds", + "timezone_hour": "timezone_hour", + "timezone_minute": "timezone_minute", +} + +COMPOUND_KEYWORDS = { + selectable.CompoundSelect.UNION: "UNION", + selectable.CompoundSelect.UNION_ALL: "UNION ALL", + selectable.CompoundSelect.EXCEPT: "EXCEPT", + selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL", + selectable.CompoundSelect.INTERSECT: "INTERSECT", + selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL", +} + + +RM_RENDERED_NAME = 0 +RM_NAME = 1 +RM_OBJECTS = 2 +RM_TYPE = 3 + + +ExpandedState = collections.namedtuple( + "ExpandedState", + [ + "statement", + "additional_parameters", + "processors", + "positiontup", + "parameter_expansion", + ], +) + + +NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0) + +COLLECT_CARTESIAN_PRODUCTS = util.symbol( + "COLLECT_CARTESIAN_PRODUCTS", + "Collect data on FROMs and cartesian products and gather " + "into 'self.from_linter'", + canonical=1, +) + +WARN_LINTING = util.symbol( + "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2 +) + +FROM_LINTING = util.symbol( + "FROM_LINTING", + "Warn for cartesian products; " + "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING", + canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING, +) + + +class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): + def lint(self, start=None): + froms = self.froms + if not froms: + return None, None + + edges = set(self.edges) + the_rest = set(froms) + + if start is not None: + start_with = start + the_rest.remove(start_with) + else: + start_with = the_rest.pop() + + stack = collections.deque([start_with]) + + while stack and the_rest: + node = stack.popleft() + the_rest.discard(node) + + # comparison of nodes in edges here is based on hash equality, as + # there are "annotated" elements that match the non-annotated ones. + # to remove the need for in-python hash() calls, use native + # containment routines (e.g. "node in edge", "edge.index(node)") + to_remove = {edge for edge in edges if node in edge} + + # appendleft the node in each edge that is not + # the one that matched. + stack.extendleft(edge[not edge.index(node)] for edge in to_remove) + edges.difference_update(to_remove) + + # FROMS left over? boom + if the_rest: + return the_rest, start_with + else: + return None, None + + def warn(self): + the_rest, start_with = self.lint() + + # FROMS left over? boom + if the_rest: + + froms = the_rest + if froms: + template = ( + "SELECT statement has a cartesian product between " + "FROM element(s) {froms} and " + 'FROM element "{start}". Apply join condition(s) ' + "between each element to resolve." + ) + froms_str = ", ".join( + '"{elem}"'.format(elem=self.froms[from_]) + for from_ in froms + ) + message = template.format( + froms=froms_str, start=self.froms[start_with] + ) + + util.warn(message) + + +class Compiled(object): + + """Represent a compiled SQL or DDL expression. + + The ``__str__`` method of the ``Compiled`` object should produce + the actual text of the statement. ``Compiled`` objects are + specific to their underlying database dialect, and also may + or may not be specific to the columns referenced within a + particular set of bind parameters. In no case should the + ``Compiled`` object be dependent on the actual values of those + bind parameters, even though it may reference those values as + defaults. + """ + + _cached_metadata = None + + _result_columns = None + + schema_translate_map = None + + execution_options = util.EMPTY_DICT + """ + Execution options propagated from the statement. In some cases, + sub-elements of the statement can modify these. + """ + + _annotations = util.EMPTY_DICT + + compile_state = None + """Optional :class:`.CompileState` object that maintains additional + state used by the compiler. + + Major executable objects such as :class:`_expression.Insert`, + :class:`_expression.Update`, :class:`_expression.Delete`, + :class:`_expression.Select` will generate this + state when compiled in order to calculate additional information about the + object. For the top level object that is to be executed, the state can be + stored here where it can also have applicability towards result set + processing. + + .. versionadded:: 1.4 + + """ + + dml_compile_state = None + """Optional :class:`.CompileState` assigned at the same point that + .isinsert, .isupdate, or .isdelete is assigned. + + This will normally be the same object as .compile_state, with the + exception of cases like the :class:`.ORMFromStatementCompileState` + object. + + .. versionadded:: 1.4.40 + + """ + + cache_key = None + _gen_time = None + + def __init__( + self, + dialect, + statement, + schema_translate_map=None, + render_schema_translate=False, + compile_kwargs=util.immutabledict(), + ): + """Construct a new :class:`.Compiled` object. + + :param dialect: :class:`.Dialect` to compile against. + + :param statement: :class:`_expression.ClauseElement` to be compiled. + + :param schema_translate_map: dictionary of schema names to be + translated when forming the resultant SQL + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`schema_translating` + + :param compile_kwargs: additional kwargs that will be + passed to the initial call to :meth:`.Compiled.process`. + + + """ + + self.dialect = dialect + self.preparer = self.dialect.identifier_preparer + if schema_translate_map: + self.schema_translate_map = schema_translate_map + self.preparer = self.preparer._with_schema_translate( + schema_translate_map + ) + + if statement is not None: + self.statement = statement + self.can_execute = statement.supports_execution + self._annotations = statement._annotations + if self.can_execute: + self.execution_options = statement._execution_options + self.string = self.process(self.statement, **compile_kwargs) + + if render_schema_translate: + self.string = self.preparer._render_schema_translates( + self.string, schema_translate_map + ) + self._gen_time = util.perf_counter() + + def _execute_on_connection( + self, connection, multiparams, params, execution_options + ): + if self.can_execute: + return connection._execute_compiled( + self, multiparams, params, execution_options + ) + else: + raise exc.ObjectNotExecutableError(self.statement) + + def visit_unsupported_compilation(self, element, err): + util.raise_( + exc.UnsupportedCompilationError(self, type(element)), + replace_context=err, + ) + + @property + def sql_compiler(self): + """Return a Compiled that is capable of processing SQL expressions. + + If this compiler is one, it would likely just return 'self'. + + """ + + raise NotImplementedError() + + def process(self, obj, **kwargs): + return obj._compiler_dispatch(self, **kwargs) + + def __str__(self): + """Return the string text of the generated SQL or DDL.""" + + return self.string or "" + + def construct_params( + self, params=None, extracted_parameters=None, escape_names=True + ): + """Return the bind params for this compiled object. + + :param params: a dict of string/object pairs whose values will + override bind values compiled in to the + statement. + """ + + raise NotImplementedError() + + @property + def params(self): + """Return the bind params for this compiled object.""" + return self.construct_params() + + +class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)): + """Produces DDL specification for TypeEngine objects.""" + + ensure_kwarg = r"visit_\w+" + + def __init__(self, dialect): + self.dialect = dialect + + def process(self, type_, **kw): + return type_._compiler_dispatch(self, **kw) + + def visit_unsupported_compilation(self, element, err, **kw): + util.raise_( + exc.UnsupportedCompilationError(self, element), + replace_context=err, + ) + + +# this was a Visitable, but to allow accurate detection of +# column elements this is actually a column element +class _CompileLabel(elements.ColumnElement): + + """lightweight label object which acts as an expression.Label.""" + + __visit_name__ = "label" + __slots__ = "element", "name" + + def __init__(self, col, name, alt_names=()): + self.element = col + self.name = name + self._alt_names = (col,) + alt_names + + @property + def proxy_set(self): + return self.element.proxy_set + + @property + def type(self): + return self.element.type + + def self_group(self, **kw): + return self + + +class SQLCompiler(Compiled): + """Default implementation of :class:`.Compiled`. + + Compiles :class:`_expression.ClauseElement` objects into SQL strings. + + """ + + extract_map = EXTRACT_MAP + + compound_keywords = COMPOUND_KEYWORDS + + isdelete = isinsert = isupdate = False + """class-level defaults which can be set at the instance + level to define if this Compiled instance represents + INSERT/UPDATE/DELETE + """ + + isplaintext = False + + returning = None + """holds the "returning" collection of columns if + the statement is CRUD and defines returning columns + either implicitly or explicitly + """ + + returning_precedes_values = False + """set to True classwide to generate RETURNING + clauses before the VALUES or WHERE clause (i.e. MSSQL) + """ + + render_table_with_column_in_update_from = False + """set to True classwide to indicate the SET clause + in a multi-table UPDATE statement should qualify + columns with the table name (i.e. MySQL only) + """ + + ansi_bind_rules = False + """SQL 92 doesn't allow bind parameters to be used + in the columns clause of a SELECT, nor does it allow + ambiguous expressions like "? = ?". A compiler + subclass can set this flag to False if the target + driver/DB enforces this + """ + + _textual_ordered_columns = False + """tell the result object that the column names as rendered are important, + but they are also "ordered" vs. what is in the compiled object here. + """ + + _ordered_columns = True + """ + if False, means we can't be sure the list of entries + in _result_columns is actually the rendered order. Usually + True unless using an unordered TextualSelect. + """ + + _loose_column_name_matching = False + """tell the result object that the SQL statement is textual, wants to match + up to Column objects, and may be using the ._tq_label in the SELECT rather + than the base name. + + """ + + _numeric_binds = False + """ + True if paramstyle is "numeric". This paramstyle is trickier than + all the others. + + """ + + _render_postcompile = False + """ + whether to render out POSTCOMPILE params during the compile phase. + + """ + + insert_single_values_expr = None + """When an INSERT is compiled with a single set of parameters inside + a VALUES expression, the string is assigned here, where it can be + used for insert batching schemes to rewrite the VALUES expression. + + .. versionadded:: 1.3.8 + + """ + + literal_execute_params = frozenset() + """bindparameter objects that are rendered as literal values at statement + execution time. + + """ + + post_compile_params = frozenset() + """bindparameter objects that are rendered as bound parameter placeholders + at statement execution time. + + """ + + escaped_bind_names = util.EMPTY_DICT + """Late escaping of bound parameter names that has to be converted + to the original name when looking in the parameter dictionary. + + """ + + has_out_parameters = False + """if True, there are bindparam() objects that have the isoutparam + flag set.""" + + insert_prefetch = update_prefetch = () + + postfetch_lastrowid = False + """if True, and this in insert, use cursor.lastrowid to populate + result.inserted_primary_key. """ + + _cache_key_bind_match = None + """a mapping that will relate the BindParameter object we compile + to those that are part of the extracted collection of parameters + in the cache key, if we were given a cache key. + + """ + + positiontup = None + """for a compiled construct that uses a positional paramstyle, will be + a sequence of strings, indicating the names of bound parameters in order. + + This is used in order to render bound parameters in their correct order, + and is combined with the :attr:`_sql.Compiled.params` dictionary to + render parameters. + + .. seealso:: + + :ref:`faq_sql_expression_string` - includes a usage example for + debugging use cases. + + """ + + inline = False + + def __init__( + self, + dialect, + statement, + cache_key=None, + column_keys=None, + for_executemany=False, + linting=NO_LINTING, + **kwargs + ): + """Construct a new :class:`.SQLCompiler` object. + + :param dialect: :class:`.Dialect` to be used + + :param statement: :class:`_expression.ClauseElement` to be compiled + + :param column_keys: a list of column names to be compiled into an + INSERT or UPDATE statement. + + :param for_executemany: whether INSERT / UPDATE statements should + expect that they are to be invoked in an "executemany" style, + which may impact how the statement will be expected to return the + values of defaults and autoincrement / sequences and similar. + Depending on the backend and driver in use, support for retrieving + these values may be disabled which means SQL expressions may + be rendered inline, RETURNING may not be rendered, etc. + + :param kwargs: additional keyword arguments to be consumed by the + superclass. + + """ + self.column_keys = column_keys + + self.cache_key = cache_key + + if cache_key: + self._cache_key_bind_match = ckbm = { + b.key: b for b in cache_key[1] + } + ckbm.update({b: [b] for b in cache_key[1]}) + + # compile INSERT/UPDATE defaults/sequences to expect executemany + # style execution, which may mean no pre-execute of defaults, + # or no RETURNING + self.for_executemany = for_executemany + + self.linting = linting + + # a dictionary of bind parameter keys to BindParameter + # instances. + self.binds = {} + + # a dictionary of BindParameter instances to "compiled" names + # that are actually present in the generated SQL + self.bind_names = util.column_dict() + + # stack which keeps track of nested SELECT statements + self.stack = [] + + # relates label names in the final SQL to a tuple of local + # column/label name, ColumnElement object (if any) and + # TypeEngine. CursorResult uses this for type processing and + # column targeting + self._result_columns = [] + + # true if the paramstyle is positional + self.positional = dialect.positional + if self.positional: + self.positiontup = [] + self._numeric_binds = dialect.paramstyle == "numeric" + self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + + self.ctes = None + + self.label_length = ( + dialect.label_length or dialect.max_identifier_length + ) + + # a map which tracks "anonymous" identifiers that are created on + # the fly here + self.anon_map = prefix_anon_map() + + # a map which tracks "truncated" names based on + # dialect.label_length or dialect.max_identifier_length + self.truncated_names = {} + + Compiled.__init__(self, dialect, statement, **kwargs) + + if self.isinsert or self.isupdate or self.isdelete: + if statement._returning: + self.returning = statement._returning + + if self.isinsert or self.isupdate: + if statement._inline: + self.inline = True + elif self.for_executemany and ( + not self.isinsert + or ( + self.dialect.insert_executemany_returning + and statement._return_defaults + ) + ): + self.inline = True + + if self.positional and self._numeric_binds: + self._apply_numbered_params() + + if self._render_postcompile: + self._process_parameters_for_postcompile(_populate_self=True) + + @property + def current_executable(self): + """Return the current 'executable' that is being compiled. + + This is currently the :class:`_sql.Select`, :class:`_sql.Insert`, + :class:`_sql.Update`, :class:`_sql.Delete`, + :class:`_sql.CompoundSelect` object that is being compiled. + Specifically it's assigned to the ``self.stack`` list of elements. + + When a statement like the above is being compiled, it normally + is also assigned to the ``.statement`` attribute of the + :class:`_sql.Compiler` object. However, all SQL constructs are + ultimately nestable, and this attribute should never be consulted + by a ``visit_`` method, as it is not guaranteed to be assigned + nor guaranteed to correspond to the current statement being compiled. + + .. versionadded:: 1.3.21 + + For compatibility with previous versions, use the following + recipe:: + + statement = getattr(self, "current_executable", False) + if statement is False: + statement = self.stack[-1]["selectable"] + + For versions 1.4 and above, ensure only .current_executable + is used; the format of "self.stack" may change. + + + """ + try: + return self.stack[-1]["selectable"] + except IndexError as ie: + util.raise_( + IndexError("Compiler does not have a stack entry"), + replace_context=ie, + ) + + @property + def prefetch(self): + return list(self.insert_prefetch + self.update_prefetch) + + @util.memoized_property + def _global_attributes(self): + return {} + + @util.memoized_instancemethod + def _init_cte_state(self): + """Initialize collections related to CTEs only if + a CTE is located, to save on the overhead of + these collections otherwise. + + """ + # collect CTEs to tack on top of a SELECT + # To store the query to print - Dict[cte, text_query] + self.ctes = util.OrderedDict() + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + self.ctes_by_level_name = {} + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name)] + self.level_name_by_cte = {} + + self.ctes_recursive = False + if self.positional: + self.cte_positional = {} + + @contextlib.contextmanager + def _nested_result(self): + """special API to support the use case of 'nested result sets'""" + result_columns, ordered_columns = ( + self._result_columns, + self._ordered_columns, + ) + self._result_columns, self._ordered_columns = [], False + + try: + if self.stack: + entry = self.stack[-1] + entry["need_result_map_for_nested"] = True + else: + entry = None + yield self._result_columns, self._ordered_columns + finally: + if entry: + entry.pop("need_result_map_for_nested") + self._result_columns, self._ordered_columns = ( + result_columns, + ordered_columns, + ) + + def _apply_numbered_params(self): + poscount = itertools.count(1) + self.string = re.sub( + r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string + ) + + @util.memoized_property + def _bind_processors(self): + + return dict( + ( + key, + value, + ) + for key, value in ( + ( + self.bind_names[bindparam], + bindparam.type._cached_bind_processor(self.dialect) + if not bindparam.type._is_tuple_type + else tuple( + elem_type._cached_bind_processor(self.dialect) + for elem_type in bindparam.type.types + ), + ) + for bindparam in self.bind_names + ) + if value is not None + ) + + def is_subquery(self): + return len(self.stack) > 1 + + @property + def sql_compiler(self): + return self + + def construct_params( + self, + params=None, + _group_number=None, + _check=True, + extracted_parameters=None, + escape_names=True, + ): + """return a dictionary of bind parameter keys and values""" + + has_escaped_names = escape_names and bool(self.escaped_bind_names) + + if extracted_parameters: + # related the bound parameters collected in the original cache key + # to those collected in the incoming cache key. They will not have + # matching names but they will line up positionally in the same + # way. The parameters present in self.bind_names may be clones of + # these original cache key params in the case of DML but the .key + # will be guaranteed to match. + try: + orig_extracted = self.cache_key[1] + except TypeError as err: + util.raise_( + exc.CompileError( + "This compiled object has no original cache key; " + "can't pass extracted_parameters to construct_params" + ), + replace_context=err, + ) + + ckbm = self._cache_key_bind_match + resolved_extracted = { + bind: extracted + for b, extracted in zip(orig_extracted, extracted_parameters) + for bind in ckbm[b] + } + else: + resolved_extracted = None + + if params: + pd = {} + for bindparam, name in self.bind_names.items(): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if has_escaped_names + else name + ) + + if bindparam.key in params: + pd[escaped_name] = params[bindparam.key] + elif name in params: + pd[escaped_name] = params[name] + + elif _check and bindparam.required: + if _group_number: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) + else: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key, + code="cd3x", + ) + else: + if resolved_extracted: + value_param = resolved_extracted.get( + bindparam, bindparam + ) + else: + value_param = bindparam + + if bindparam.callable: + pd[escaped_name] = value_param.effective_value + else: + pd[escaped_name] = value_param.value + return pd + else: + pd = {} + for bindparam, name in self.bind_names.items(): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if has_escaped_names + else name + ) + + if _check and bindparam.required: + if _group_number: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) + else: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key, + code="cd3x", + ) + + if resolved_extracted: + value_param = resolved_extracted.get(bindparam, bindparam) + else: + value_param = bindparam + + if bindparam.callable: + pd[escaped_name] = value_param.effective_value + else: + pd[escaped_name] = value_param.value + return pd + + @util.memoized_instancemethod + def _get_set_input_sizes_lookup( + self, include_types=None, exclude_types=None + ): + if not hasattr(self, "bind_names"): + return None + + dialect = self.dialect + dbapi = self.dialect.dbapi + + # _unwrapped_dialect_impl() is necessary so that we get the + # correct dialect type for a custom TypeDecorator, or a Variant, + # which is also a TypeDecorator. Special types like Interval, + # that use TypeDecorator but also might be mapped directly + # for a dialect impl, also subclass Emulated first which overrides + # this behavior in those cases to behave like the default. + + if include_types is None and exclude_types is None: + + def _lookup_type(typ): + dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi) + return dbtype + + else: + + def _lookup_type(typ): + # note we get dbtype from the possibly TypeDecorator-wrapped + # dialect_impl, but the dialect_impl itself that we use for + # include/exclude is the unwrapped version. + + dialect_impl = typ._unwrapped_dialect_impl(dialect) + + dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi) + + if ( + dbtype is not None + and ( + exclude_types is None + or dbtype not in exclude_types + and type(dialect_impl) not in exclude_types + ) + and ( + include_types is None + or dbtype in include_types + or type(dialect_impl) in include_types + ) + ): + return dbtype + else: + return None + + inputsizes = {} + literal_execute_params = self.literal_execute_params + + for bindparam in self.bind_names: + if bindparam in literal_execute_params: + continue + + if bindparam.type._is_tuple_type: + inputsizes[bindparam] = [ + _lookup_type(typ) for typ in bindparam.type.types + ] + else: + inputsizes[bindparam] = _lookup_type(bindparam.type) + + return inputsizes + + @property + def params(self): + """Return the bind param dictionary embedded into this + compiled object, for those values that are present. + + .. seealso:: + + :ref:`faq_sql_expression_string` - includes a usage example for + debugging use cases. + + """ + return self.construct_params(_check=False) + + def _process_parameters_for_postcompile( + self, parameters=None, _populate_self=False + ): + """handle special post compile parameters. + + These include: + + * "expanding" parameters -typically IN tuples that are rendered + on a per-parameter basis for an otherwise fixed SQL statement string. + + * literal_binds compiled with the literal_execute flag. Used for + things like SQL Server "TOP N" where the driver does not accommodate + N as a bound parameter. + + """ + + if parameters is None: + parameters = self.construct_params(escape_names=False) + + expanded_parameters = {} + if self.positional: + positiontup = [] + else: + positiontup = None + + processors = self._bind_processors + + new_processors = {} + + if self.positional and self._numeric_binds: + # I'm not familiar with any DBAPI that uses 'numeric'. + # strategy would likely be to make use of numbers greater than + # the highest number present; then for expanding parameters, + # append them to the end of the parameter list. that way + # we avoid having to renumber all the existing parameters. + raise NotImplementedError( + "'post-compile' bind parameters are not supported with " + "the 'numeric' paramstyle at this time." + ) + + replacement_expressions = {} + to_update_sets = {} + + # notes: + # *unescaped* parameter names in: + # self.bind_names, self.binds, self._bind_processors + # + # *escaped* parameter names in: + # construct_params(), replacement_expressions + + for name in ( + self.positiontup if self.positional else self.bind_names.values() + ): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if self.escaped_bind_names + else name + ) + + parameter = self.binds[name] + if parameter in self.literal_execute_params: + if escaped_name not in replacement_expressions: + value = parameters.pop(name) + + replacement_expressions[ + escaped_name + ] = self.render_literal_bindparam( + parameter, render_literal_value=value + ) + continue + + if parameter in self.post_compile_params: + if escaped_name in replacement_expressions: + to_update = to_update_sets[escaped_name] + else: + # we are removing the parameter from parameters + # because it is a list value, which is not expected by + # TypeEngine objects that would otherwise be asked to + # process it. the single name is being replaced with + # individual numbered parameters for each value in the + # param. + # + # note we are also inserting *escaped* parameter names + # into the given dictionary. default dialect will + # use these param names directly as they will not be + # in the escaped_bind_names dictionary. + values = parameters.pop(name) + + leep = self._literal_execute_expanding_parameter + to_update, replacement_expr = leep( + escaped_name, parameter, values + ) + + to_update_sets[escaped_name] = to_update + replacement_expressions[escaped_name] = replacement_expr + + if not parameter.literal_execute: + parameters.update(to_update) + if parameter.type._is_tuple_type: + new_processors.update( + ( + "%s_%s_%s" % (name, i, j), + processors[name][j - 1], + ) + for i, tuple_element in enumerate(values, 1) + for j, value in enumerate(tuple_element, 1) + if name in processors + and processors[name][j - 1] is not None + ) + else: + new_processors.update( + (key, processors[name]) + for key, value in to_update + if name in processors + ) + if self.positional: + positiontup.extend(name for name, value in to_update) + expanded_parameters[name] = [ + expand_key for expand_key, value in to_update + ] + elif self.positional: + positiontup.append(name) + + def process_expanding(m): + key = m.group(1) + expr = replacement_expressions[key] + + # if POSTCOMPILE included a bind_expression, render that + # around each element + if m.group(2): + tok = m.group(2).split("~~") + be_left, be_right = tok[1], tok[3] + expr = ", ".join( + "%s%s%s" % (be_left, exp, be_right) + for exp in expr.split(", ") + ) + return expr + + statement = re.sub( + r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]", + process_expanding, + self.string, + ) + + expanded_state = ExpandedState( + statement, + parameters, + new_processors, + positiontup, + expanded_parameters, + ) + + if _populate_self: + # this is for the "render_postcompile" flag, which is not + # otherwise used internally and is for end-user debugging and + # special use cases. + self.string = expanded_state.statement + self._bind_processors.update(expanded_state.processors) + self.positiontup = expanded_state.positiontup + self.post_compile_params = frozenset() + for key in expanded_state.parameter_expansion: + bind = self.binds.pop(key) + self.bind_names.pop(bind) + for value, expanded_key in zip( + bind.value, expanded_state.parameter_expansion[key] + ): + self.binds[expanded_key] = new_param = bind._with_value( + value + ) + self.bind_names[new_param] = expanded_key + + return expanded_state + + @util.preload_module("sqlalchemy.engine.cursor") + def _create_result_map(self): + """utility method used for unit tests only.""" + cursor = util.preloaded.engine_cursor + return cursor.CursorResultMetaData._create_description_match_map( + self._result_columns + ) + + @util.memoized_property + def _within_exec_param_key_getter(self): + getter = self._key_getters_for_crud_column[2] + return getter + + @util.memoized_property + @util.preload_module("sqlalchemy.engine.result") + def _inserted_primary_key_from_lastrowid_getter(self): + result = util.preloaded.engine_result + + param_key_getter = self._within_exec_param_key_getter + table = self.statement.table + + getters = [ + (operator.methodcaller("get", param_key_getter(col), None), col) + for col in table.primary_key + ] + + autoinc_col = table._autoincrement_column + if autoinc_col is not None: + # apply type post processors to the lastrowid + proc = autoinc_col.type._cached_result_processor( + self.dialect, None + ) + else: + proc = None + + row_fn = result.result_tuple([col.key for col in table.primary_key]) + + def get(lastrowid, parameters): + """given cursor.lastrowid value and the parameters used for INSERT, + return a "row" that represents the primary key, either by + using the "lastrowid" or by extracting values from the parameters + that were sent along with the INSERT. + + """ + if proc is not None: + lastrowid = proc(lastrowid) + + if lastrowid is None: + return row_fn(getter(parameters) for getter, col in getters) + else: + return row_fn( + lastrowid if col is autoinc_col else getter(parameters) + for getter, col in getters + ) + + return get + + @util.memoized_property + @util.preload_module("sqlalchemy.engine.result") + def _inserted_primary_key_from_returning_getter(self): + result = util.preloaded.engine_result + + param_key_getter = self._within_exec_param_key_getter + table = self.statement.table + + ret = {col: idx for idx, col in enumerate(self.returning)} + + getters = [ + (operator.itemgetter(ret[col]), True) + if col in ret + else ( + operator.methodcaller("get", param_key_getter(col), None), + False, + ) + for col in table.primary_key + ] + + row_fn = result.result_tuple([col.key for col in table.primary_key]) + + def get(row, parameters): + return row_fn( + getter(row) if use_row else getter(parameters) + for getter, use_row in getters + ) + + return get + + def default_from(self): + """Called when a SELECT statement has no froms, and no FROM clause is + to be appended. + + Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output. + + """ + return "" + + def visit_grouping(self, grouping, asfrom=False, **kwargs): + return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" + + def visit_select_statement_grouping(self, grouping, **kwargs): + return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" + + def visit_label_reference( + self, element, within_columns_clause=False, **kwargs + ): + if self.stack and self.dialect.supports_simple_order_by_label: + compile_state = self.stack[-1]["compile_state"] + + ( + with_cols, + only_froms, + only_cols, + ) = compile_state._label_resolve_dict + if within_columns_clause: + resolve_dict = only_froms + else: + resolve_dict = only_cols + + # this can be None in the case that a _label_reference() + # were subject to a replacement operation, in which case + # the replacement of the Label element may have changed + # to something else like a ColumnClause expression. + order_by_elem = element.element._order_by_label_element + + if ( + order_by_elem is not None + and order_by_elem.name in resolve_dict + and order_by_elem.shares_lineage( + resolve_dict[order_by_elem.name] + ) + ): + kwargs[ + "render_label_as_label" + ] = element.element._order_by_label_element + return self.process( + element.element, + within_columns_clause=within_columns_clause, + **kwargs + ) + + def visit_textual_label_reference( + self, element, within_columns_clause=False, **kwargs + ): + if not self.stack: + # compiling the element outside of the context of a SELECT + return self.process(element._text_clause) + + compile_state = self.stack[-1]["compile_state"] + with_cols, only_froms, only_cols = compile_state._label_resolve_dict + try: + if within_columns_clause: + col = only_froms[element.element] + else: + col = with_cols[element.element] + except KeyError as err: + coercions._no_text_coercion( + element.element, + extra=( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ), + exc_cls=exc.CompileError, + err=err, + ) + else: + kwargs["render_label_as_label"] = col + return self.process( + col, within_columns_clause=within_columns_clause, **kwargs + ) + + def visit_label( + self, + label, + add_to_result_map=None, + within_label_clause=False, + within_columns_clause=False, + render_label_as_label=None, + result_map_targets=(), + **kw + ): + # only render labels within the columns clause + # or ORDER BY clause of a select. dialect-specific compilers + # can modify this behavior. + render_label_with_as = ( + within_columns_clause and not within_label_clause + ) + render_label_only = render_label_as_label is label + + if render_label_only or render_label_with_as: + if isinstance(label.name, elements._truncated_label): + labelname = self._truncated_identifier("colident", label.name) + else: + labelname = label.name + + if render_label_with_as: + if add_to_result_map is not None: + add_to_result_map( + labelname, + label.name, + (label, labelname) + label._alt_names + result_map_targets, + label.type, + ) + return ( + label.element._compiler_dispatch( + self, + within_columns_clause=True, + within_label_clause=True, + **kw + ) + + OPERATORS[operators.as_] + + self.preparer.format_label(label, labelname) + ) + elif render_label_only: + return self.preparer.format_label(label, labelname) + else: + return label.element._compiler_dispatch( + self, within_columns_clause=False, **kw + ) + + def _fallback_column_name(self, column): + raise exc.CompileError( + "Cannot compile Column object until " "its 'name' is assigned." + ) + + def visit_lambda_element(self, element, **kw): + sql_element = element._resolved + return self.process(sql_element, **kw) + + def visit_column( + self, + column, + add_to_result_map=None, + include_table=True, + result_map_targets=(), + **kwargs + ): + name = orig_name = column.name + if name is None: + name = self._fallback_column_name(column) + + is_literal = column.is_literal + if not is_literal and isinstance(name, elements._truncated_label): + name = self._truncated_identifier("colident", name) + + if add_to_result_map is not None: + targets = (column, name, column.key) + result_map_targets + if column._tq_label: + targets += (column._tq_label,) + + add_to_result_map(name, orig_name, targets, column.type) + + if is_literal: + # note we are not currently accommodating for + # literal_column(quoted_name('ident', True)) here + name = self.escape_literal_column(name) + else: + name = self.preparer.quote(name) + table = column.table + if table is None or not include_table or not table.named_with_column: + return name + else: + effective_schema = self.preparer.schema_for_object(table) + + if effective_schema: + schema_prefix = ( + self.preparer.quote_schema(effective_schema) + "." + ) + else: + schema_prefix = "" + tablename = table.name + if isinstance(tablename, elements._truncated_label): + tablename = self._truncated_identifier("alias", tablename) + + return schema_prefix + self.preparer.quote(tablename) + "." + name + + def visit_collation(self, element, **kw): + return self.preparer.format_collation(element.collation) + + def visit_fromclause(self, fromclause, **kwargs): + return fromclause.name + + def visit_index(self, index, **kwargs): + return index.name + + def visit_typeclause(self, typeclause, **kw): + kw["type_expression"] = typeclause + kw["identifier_preparer"] = self.preparer + return self.dialect.type_compiler.process(typeclause.type, **kw) + + def post_process_text(self, text): + if self.preparer._double_percents: + text = text.replace("%", "%%") + return text + + def escape_literal_column(self, text): + if self.preparer._double_percents: + text = text.replace("%", "%%") + return text + + def visit_textclause(self, textclause, add_to_result_map=None, **kw): + def do_bindparam(m): + name = m.group(1) + if name in textclause._bindparams: + return self.process(textclause._bindparams[name], **kw) + else: + return self.bindparam_string(name, **kw) + + if not self.stack: + self.isplaintext = True + + if add_to_result_map: + # text() object is present in the columns clause of a + # select(). Add a no-name entry to the result map so that + # row[text()] produces a result + add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE) + + # un-escape any \:params + return BIND_PARAMS_ESC.sub( + lambda m: m.group(1), + BIND_PARAMS.sub( + do_bindparam, self.post_process_text(textclause.text) + ), + ) + + def visit_textual_select( + self, taf, compound_index=None, asfrom=False, **kw + ): + + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + new_entry = { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": taf, + } + self.stack.append(new_entry) + + if taf._independent_ctes: + for cte in taf._independent_ctes: + cte._compiler_dispatch(self, **kw) + + populate_result_map = ( + toplevel + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) + + if populate_result_map: + self._ordered_columns = ( + self._textual_ordered_columns + ) = taf.positional + + # enable looser result column matching when the SQL text links to + # Column objects by name only + self._loose_column_name_matching = not taf.positional and bool( + taf.column_args + ) + + for c in taf.column_args: + self.process( + c, + within_columns_clause=True, + add_to_result_map=self._add_to_result_map, + ) + + text = self.process(taf.element, **kw) + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + self.stack.pop(-1) + + return text + + def visit_null(self, expr, **kw): + return "NULL" + + def visit_true(self, expr, **kw): + if self.dialect.supports_native_boolean: + return "true" + else: + return "1" + + def visit_false(self, expr, **kw): + if self.dialect.supports_native_boolean: + return "false" + else: + return "0" + + def _generate_delimited_list(self, elements, separator, **kw): + return separator.join( + s + for s in (c._compiler_dispatch(self, **kw) for c in elements) + if s + ) + + def _generate_delimited_and_list(self, clauses, **kw): + + lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean( + operators.and_, + elements.True_._singleton, + elements.False_._singleton, + clauses, + ) + if lcc == 1: + return clauses[0]._compiler_dispatch(self, **kw) + else: + separator = OPERATORS[operators.and_] + return separator.join( + s + for s in (c._compiler_dispatch(self, **kw) for c in clauses) + if s + ) + + def visit_tuple(self, clauselist, **kw): + return "(%s)" % self.visit_clauselist(clauselist, **kw) + + def visit_clauselist(self, clauselist, **kw): + sep = clauselist.operator + if sep is None: + sep = " " + else: + sep = OPERATORS[clauselist.operator] + + return self._generate_delimited_list(clauselist.clauses, sep, **kw) + + def visit_case(self, clause, **kwargs): + x = "CASE " + if clause.value is not None: + x += clause.value._compiler_dispatch(self, **kwargs) + " " + for cond, result in clause.whens: + x += ( + "WHEN " + + cond._compiler_dispatch(self, **kwargs) + + " THEN " + + result._compiler_dispatch(self, **kwargs) + + " " + ) + if clause.else_ is not None: + x += ( + "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " " + ) + x += "END" + return x + + def visit_type_coerce(self, type_coerce, **kw): + return type_coerce.typed_expression._compiler_dispatch(self, **kw) + + def visit_cast(self, cast, **kwargs): + return "CAST(%s AS %s)" % ( + cast.clause._compiler_dispatch(self, **kwargs), + cast.typeclause._compiler_dispatch(self, **kwargs), + ) + + def _format_frame_clause(self, range_, **kw): + + return "%s AND %s" % ( + "UNBOUNDED PRECEDING" + if range_[0] is elements.RANGE_UNBOUNDED + else "CURRENT ROW" + if range_[0] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[0])), **kw),) + if range_[0] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[0]), **kw),), + "UNBOUNDED FOLLOWING" + if range_[1] is elements.RANGE_UNBOUNDED + else "CURRENT ROW" + if range_[1] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[1])), **kw),) + if range_[1] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[1]), **kw),), + ) + + def visit_over(self, over, **kwargs): + if over.range_: + range_ = "RANGE BETWEEN %s" % self._format_frame_clause( + over.range_, **kwargs + ) + elif over.rows: + range_ = "ROWS BETWEEN %s" % self._format_frame_clause( + over.rows, **kwargs + ) + else: + range_ = None + + return "%s OVER (%s)" % ( + over.element._compiler_dispatch(self, **kwargs), + " ".join( + [ + "%s BY %s" + % (word, clause._compiler_dispatch(self, **kwargs)) + for word, clause in ( + ("PARTITION", over.partition_by), + ("ORDER", over.order_by), + ) + if clause is not None and len(clause) + ] + + ([range_] if range_ else []) + ), + ) + + def visit_withingroup(self, withingroup, **kwargs): + return "%s WITHIN GROUP (ORDER BY %s)" % ( + withingroup.element._compiler_dispatch(self, **kwargs), + withingroup.order_by._compiler_dispatch(self, **kwargs), + ) + + def visit_funcfilter(self, funcfilter, **kwargs): + return "%s FILTER (WHERE %s)" % ( + funcfilter.func._compiler_dispatch(self, **kwargs), + funcfilter.criterion._compiler_dispatch(self, **kwargs), + ) + + def visit_extract(self, extract, **kwargs): + field = self.extract_map.get(extract.field, extract.field) + return "EXTRACT(%s FROM %s)" % ( + field, + extract.expr._compiler_dispatch(self, **kwargs), + ) + + def visit_scalar_function_column(self, element, **kw): + compiled_fn = self.visit_function(element.fn, **kw) + compiled_col = self.visit_column(element, **kw) + return "(%s).%s" % (compiled_fn, compiled_col) + + def visit_function(self, func, add_to_result_map=None, **kwargs): + if add_to_result_map is not None: + add_to_result_map(func.name, func.name, (), func.type) + + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + if disp: + text = disp(func, **kwargs) + else: + name = FUNCTIONS.get(func._deannotate().__class__, None) + if name: + if func._has_args: + name += "%(expr)s" + else: + name = func.name + name = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) + or isinstance(name, elements.quoted_name) + else name + ) + name = name + "%(expr)s" + text = ".".join( + [ + ( + self.preparer.quote(tok) + if self.preparer._requires_quotes_illegal_chars(tok) + or isinstance(name, elements.quoted_name) + else tok + ) + for tok in func.packagenames + ] + + [name] + ) % {"expr": self.function_argspec(func, **kwargs)} + + if func._with_ordinality: + text += " WITH ORDINALITY" + return text + + def visit_next_value_func(self, next_value, **kw): + return self.visit_sequence(next_value.sequence) + + def visit_sequence(self, sequence, **kw): + raise NotImplementedError( + "Dialect '%s' does not support sequence increments." + % self.dialect.name + ) + + def function_argspec(self, func, **kwargs): + return func.clause_expr._compiler_dispatch(self, **kwargs) + + def visit_compound_select( + self, cs, asfrom=False, compound_index=None, **kwargs + ): + toplevel = not self.stack + + compile_state = cs._compile_state_factory(cs, self, **kwargs) + + if toplevel and not self.compile_state: + self.compile_state = compile_state + + compound_stmt = compile_state.statement + + entry = self._default_stack_entry if toplevel else self.stack[-1] + need_result_map = toplevel or ( + not compound_index + and entry.get("need_result_map_for_compound", False) + ) + + # indicates there is already a CompoundSelect in play + if compound_index == 0: + entry["select_0"] = cs + + self.stack.append( + { + "correlate_froms": entry["correlate_froms"], + "asfrom_froms": entry["asfrom_froms"], + "selectable": cs, + "compile_state": compile_state, + "need_result_map_for_compound": need_result_map, + } + ) + + if compound_stmt._independent_ctes: + for cte in compound_stmt._independent_ctes: + cte._compiler_dispatch(self, **kwargs) + + keyword = self.compound_keywords.get(cs.keyword) + + text = (" " + keyword + " ").join( + ( + c._compiler_dispatch( + self, asfrom=asfrom, compound_index=i, **kwargs + ) + for i, c in enumerate(cs.selects) + ) + ) + + kwargs["include_table"] = False + text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs)) + text += self.order_by_clause(cs, **kwargs) + if cs._has_row_limiting_clause: + text += self._row_limit_clause(cs, **kwargs) + + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, include_following_stack=True + ) + + text + ) + + self.stack.pop(-1) + return text + + def _row_limit_clause(self, cs, **kwargs): + if cs._fetch_clause is not None: + return self.fetch_clause(cs, **kwargs) + else: + return self.limit_clause(cs, **kwargs) + + def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): + attrname = "visit_%s_%s%s" % ( + operator_.__name__, + qualifier1, + "_" + qualifier2 if qualifier2 else "", + ) + return getattr(self, attrname, None) + + def visit_unary( + self, unary, add_to_result_map=None, result_map_targets=(), **kw + ): + + if add_to_result_map is not None: + result_map_targets += (unary,) + kw["add_to_result_map"] = add_to_result_map + kw["result_map_targets"] = result_map_targets + + if unary.operator: + if unary.modifier: + raise exc.CompileError( + "Unary expression does not support operator " + "and modifier simultaneously" + ) + disp = self._get_operator_dispatch( + unary.operator, "unary", "operator" + ) + if disp: + return disp(unary, unary.operator, **kw) + else: + return self._generate_generic_unary_operator( + unary, OPERATORS[unary.operator], **kw + ) + elif unary.modifier: + disp = self._get_operator_dispatch( + unary.modifier, "unary", "modifier" + ) + if disp: + return disp(unary, unary.modifier, **kw) + else: + return self._generate_generic_unary_modifier( + unary, OPERATORS[unary.modifier], **kw + ) + else: + raise exc.CompileError( + "Unary expression has no operator or modifier" + ) + + def visit_is_true_unary_operator(self, element, operator, **kw): + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): + return self.process(element.element, **kw) + else: + return "%s = 1" % self.process(element.element, **kw) + + def visit_is_false_unary_operator(self, element, operator, **kw): + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): + return "NOT %s" % self.process(element.element, **kw) + else: + return "%s = 0" % self.process(element.element, **kw) + + def visit_not_match_op_binary(self, binary, operator, **kw): + return "NOT %s" % self.visit_binary( + binary, override_operator=operators.match_op + ) + + def visit_not_in_op_binary(self, binary, operator, **kw): + # The brackets are required in the NOT IN operation because the empty + # case is handled using the form "(col NOT IN (null) OR 1 = 1)". + # The presence of the OR makes the brackets required. + return "(%s)" % self._generate_generic_binary( + binary, OPERATORS[operator], **kw + ) + + def visit_empty_set_op_expr(self, type_, expand_op): + if expand_op is operators.not_in_op: + if len(type_) > 1: + return "(%s)) OR (1 = 1" % ( + ", ".join("NULL" for element in type_) + ) + else: + return "NULL) OR (1 = 1" + elif expand_op is operators.in_op: + if len(type_) > 1: + return "(%s)) AND (1 != 1" % ( + ", ".join("NULL" for element in type_) + ) + else: + return "NULL) AND (1 != 1" + else: + return self.visit_empty_set_expr(type_) + + def visit_empty_set_expr(self, element_types): + raise NotImplementedError( + "Dialect '%s' does not support empty set expression." + % self.dialect.name + ) + + def _literal_execute_expanding_parameter_literal_binds( + self, parameter, values + ): + + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) + + if not values: + if typ_dialect_impl._is_tuple_type: + replacement_expression = ( + "VALUES " if self.dialect.tuple_in_values else "" + ) + self.visit_empty_set_op_expr( + parameter.type.types, parameter.expand_op + ) + + else: + replacement_expression = self.visit_empty_set_op_expr( + [parameter.type], parameter.expand_op + ) + + elif typ_dialect_impl._is_tuple_type or ( + typ_dialect_impl._isnull + and isinstance(values[0], util.collections_abc.Sequence) + and not isinstance( + values[0], util.string_types + util.binary_types + ) + ): + + replacement_expression = ( + "VALUES " if self.dialect.tuple_in_values else "" + ) + ", ".join( + "(%s)" + % ( + ", ".join( + self.render_literal_value(value, param_type) + for value, param_type in zip( + tuple_element, parameter.type.types + ) + ) + ) + for i, tuple_element in enumerate(values) + ) + else: + replacement_expression = ", ".join( + self.render_literal_value(value, parameter.type) + for value in values + ) + + return (), replacement_expression + + def _literal_execute_expanding_parameter(self, name, parameter, values): + + if parameter.literal_execute: + return self._literal_execute_expanding_parameter_literal_binds( + parameter, values + ) + + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) + + if not values: + to_update = [] + if typ_dialect_impl._is_tuple_type: + + replacement_expression = self.visit_empty_set_op_expr( + parameter.type.types, parameter.expand_op + ) + else: + replacement_expression = self.visit_empty_set_op_expr( + [parameter.type], parameter.expand_op + ) + + elif typ_dialect_impl._is_tuple_type or ( + typ_dialect_impl._isnull + and isinstance(values[0], util.collections_abc.Sequence) + and not isinstance( + values[0], util.string_types + util.binary_types + ) + ): + assert not typ_dialect_impl._is_array + to_update = [ + ("%s_%s_%s" % (name, i, j), value) + for i, tuple_element in enumerate(values, 1) + for j, value in enumerate(tuple_element, 1) + ] + replacement_expression = ( + "VALUES " if self.dialect.tuple_in_values else "" + ) + ", ".join( + "(%s)" + % ( + ", ".join( + self.bindtemplate + % {"name": to_update[i * len(tuple_element) + j][0]} + for j, value in enumerate(tuple_element) + ) + ) + for i, tuple_element in enumerate(values) + ) + else: + to_update = [ + ("%s_%s" % (name, i), value) + for i, value in enumerate(values, 1) + ] + replacement_expression = ", ".join( + self.bindtemplate % {"name": key} for key, value in to_update + ) + + return to_update, replacement_expression + + def visit_binary( + self, + binary, + override_operator=None, + eager_grouping=False, + from_linter=None, + lateral_from_linter=None, + **kw + ): + if from_linter and operators.is_comparison(binary.operator): + if lateral_from_linter is not None: + enclosing_lateral = kw["enclosing_lateral"] + lateral_from_linter.edges.update( + itertools.product( + binary.left._from_objects + [enclosing_lateral], + binary.right._from_objects + [enclosing_lateral], + ) + ) + else: + from_linter.edges.update( + itertools.product( + binary.left._from_objects, binary.right._from_objects + ) + ) + + # don't allow "? = ?" to render + if ( + self.ansi_bind_rules + and isinstance(binary.left, elements.BindParameter) + and isinstance(binary.right, elements.BindParameter) + ): + kw["literal_execute"] = True + + operator_ = override_operator or binary.operator + disp = self._get_operator_dispatch(operator_, "binary", None) + if disp: + return disp(binary, operator_, **kw) + else: + try: + opstring = OPERATORS[operator_] + except KeyError as err: + util.raise_( + exc.UnsupportedCompilationError(self, operator_), + replace_context=err, + ) + else: + return self._generate_generic_binary( + binary, + opstring, + from_linter=from_linter, + lateral_from_linter=lateral_from_linter, + **kw + ) + + def visit_function_as_comparison_op_binary(self, element, operator, **kw): + return self.process(element.sql_function, **kw) + + def visit_mod_binary(self, binary, operator, **kw): + if self.preparer._double_percents: + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) + else: + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) + + def visit_custom_op_binary(self, element, operator, **kw): + kw["eager_grouping"] = operator.eager_grouping + return self._generate_generic_binary( + element, + " " + self.escape_literal_column(operator.opstring) + " ", + **kw + ) + + def visit_custom_op_unary_operator(self, element, operator, **kw): + return self._generate_generic_unary_operator( + element, self.escape_literal_column(operator.opstring) + " ", **kw + ) + + def visit_custom_op_unary_modifier(self, element, operator, **kw): + return self._generate_generic_unary_modifier( + element, " " + self.escape_literal_column(operator.opstring), **kw + ) + + def _generate_generic_binary( + self, binary, opstring, eager_grouping=False, **kw + ): + + _in_binary = kw.get("_in_binary", False) + + kw["_in_binary"] = True + kw["_binary_op"] = binary.operator + text = ( + binary.left._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + + opstring + + binary.right._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + ) + + if _in_binary and eager_grouping: + text = "(%s)" % text + return text + + def _generate_generic_unary_operator(self, unary, opstring, **kw): + return opstring + unary.element._compiler_dispatch(self, **kw) + + def _generate_generic_unary_modifier(self, unary, opstring, **kw): + return unary.element._compiler_dispatch(self, **kw) + opstring + + @util.memoized_property + def _like_percent_literal(self): + return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE) + + def visit_contains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right).concat(percent) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_not_contains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right).concat(percent) + return self.visit_not_like_op_binary(binary, operator, **kw) + + def visit_startswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent._rconcat(binary.right) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_not_startswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent._rconcat(binary.right) + return self.visit_not_like_op_binary(binary, operator, **kw) + + def visit_endswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_not_endswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right) + return self.visit_not_like_op_binary(binary, operator, **kw) + + def visit_like_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + + # TODO: use ternary here, not "and"/ "or" + return "%s LIKE %s" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) + + def visit_not_like_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return "%s NOT LIKE %s" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) + + def visit_ilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return "lower(%s) LIKE lower(%s)" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) + + def visit_not_ilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return "lower(%s) NOT LIKE lower(%s)" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) + + def visit_between_op_binary(self, binary, operator, **kw): + symmetric = binary.modifiers.get("symmetric", False) + return self._generate_generic_binary( + binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw + ) + + def visit_not_between_op_binary(self, binary, operator, **kw): + symmetric = binary.modifiers.get("symmetric", False) + return self._generate_generic_binary( + binary, + " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ", + **kw + ) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + raise exc.CompileError( + "%s dialect does not support regular expressions" + % self.dialect.name + ) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + raise exc.CompileError( + "%s dialect does not support regular expressions" + % self.dialect.name + ) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + raise exc.CompileError( + "%s dialect does not support regular expression replacements" + % self.dialect.name + ) + + def visit_bindparam( + self, + bindparam, + within_columns_clause=False, + literal_binds=False, + skip_bind_expression=False, + literal_execute=False, + render_postcompile=False, + **kwargs + ): + if not skip_bind_expression: + impl = bindparam.type.dialect_impl(self.dialect) + if impl._has_bind_expression: + bind_expression = impl.bind_expression(bindparam) + wrapped = self.process( + bind_expression, + skip_bind_expression=True, + within_columns_clause=within_columns_clause, + literal_binds=literal_binds, + literal_execute=literal_execute, + render_postcompile=render_postcompile, + **kwargs + ) + if bindparam.expanding: + # for postcompile w/ expanding, move the "wrapped" part + # of this into the inside + m = re.match( + r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped + ) + wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % ( + m.group(2), + m.group(1), + m.group(3), + ) + return wrapped + + if not literal_binds: + literal_execute = ( + literal_execute + or bindparam.literal_execute + or (within_columns_clause and self.ansi_bind_rules) + ) + post_compile = literal_execute or bindparam.expanding + else: + post_compile = False + + if literal_binds: + ret = self.render_literal_bindparam( + bindparam, within_columns_clause=True, **kwargs + ) + if bindparam.expanding: + ret = "(%s)" % ret + return ret + + name = self._truncate_bindparam(bindparam) + + if name in self.binds: + existing = self.binds[name] + if existing is not bindparam: + if ( + (existing.unique or bindparam.unique) + and not existing.proxy_set.intersection( + bindparam.proxy_set + ) + and not existing._cloned_set.intersection( + bindparam._cloned_set + ) + ): + raise exc.CompileError( + "Bind parameter '%s' conflicts with " + "unique bind parameter of the same name" % name + ) + elif existing.expanding != bindparam.expanding: + raise exc.CompileError( + "Can't reuse bound parameter name '%s' in both " + "'expanding' (e.g. within an IN expression) and " + "non-expanding contexts. If this parameter is to " + "receive a list/array value, set 'expanding=True' on " + "it for expressions that aren't IN, otherwise use " + "a different parameter name." % (name,) + ) + elif existing._is_crud or bindparam._is_crud: + raise exc.CompileError( + "bindparam() name '%s' is reserved " + "for automatic usage in the VALUES or SET " + "clause of this " + "insert/update statement. Please use a " + "name other than column name when using bindparam() " + "with insert() or update() (for example, 'b_%s')." + % (bindparam.key, bindparam.key) + ) + + self.binds[bindparam.key] = self.binds[name] = bindparam + + # if we are given a cache key that we're going to match against, + # relate the bindparam here to one that is most likely present + # in the "extracted params" portion of the cache key. this is used + # to set up a positional mapping that is used to determine the + # correct parameters for a subsequent use of this compiled with + # a different set of parameter values. here, we accommodate for + # parameters that may have been cloned both before and after the cache + # key was been generated. + ckbm = self._cache_key_bind_match + if ckbm: + for bp in bindparam._cloned_set: + if bp.key in ckbm: + cb = ckbm[bp.key] + ckbm[cb].append(bindparam) + + if bindparam.isoutparam: + self.has_out_parameters = True + + if post_compile: + if render_postcompile: + self._render_postcompile = True + + if literal_execute: + self.literal_execute_params |= {bindparam} + else: + self.post_compile_params |= {bindparam} + + ret = self.bindparam_string( + name, + post_compile=post_compile, + expanding=bindparam.expanding, + **kwargs + ) + + if bindparam.expanding: + ret = "(%s)" % ret + return ret + + def render_literal_bindparam( + self, bindparam, render_literal_value=NO_ARG, **kw + ): + if render_literal_value is not NO_ARG: + value = render_literal_value + else: + if bindparam.value is None and bindparam.callable is None: + op = kw.get("_binary_op", None) + if op and op not in (operators.is_, operators.is_not): + util.warn_limited( + "Bound parameter '%s' rendering literal NULL in a SQL " + "expression; comparisons to NULL should not use " + "operators outside of 'is' or 'is not'", + (bindparam.key,), + ) + return self.process(sqltypes.NULLTYPE, **kw) + value = bindparam.effective_value + + if bindparam.expanding: + leep = self._literal_execute_expanding_parameter_literal_binds + to_update, replacement_expr = leep(bindparam, value) + return replacement_expr + else: + return self.render_literal_value(value, bindparam.type) + + def render_literal_value(self, value, type_): + """Render the value of a bind parameter as a quoted literal. + + This is used for statement sections that do not accept bind parameters + on the target driver/database. + + This should be implemented by subclasses using the quoting services + of the DBAPI. + + """ + + processor = type_._cached_literal_processor(self.dialect) + if processor: + return processor(value) + else: + raise NotImplementedError( + "Don't know how to literal-quote value %r" % value + ) + + def _truncate_bindparam(self, bindparam): + if bindparam in self.bind_names: + return self.bind_names[bindparam] + + bind_name = bindparam.key + if isinstance(bind_name, elements._truncated_label): + bind_name = self._truncated_identifier("bindparam", bind_name) + + # add to bind_names for translation + self.bind_names[bindparam] = bind_name + + return bind_name + + def _truncated_identifier(self, ident_class, name): + if (ident_class, name) in self.truncated_names: + return self.truncated_names[(ident_class, name)] + + anonname = name.apply_map(self.anon_map) + + if len(anonname) > self.label_length - 6: + counter = self.truncated_names.get(ident_class, 1) + truncname = ( + anonname[0 : max(self.label_length - 6, 0)] + + "_" + + hex(counter)[2:] + ) + self.truncated_names[ident_class] = counter + 1 + else: + truncname = anonname + self.truncated_names[(ident_class, name)] = truncname + return truncname + + def _anonymize(self, name): + return name % self.anon_map + + def bindparam_string( + self, + name, + positional_names=None, + post_compile=False, + expanding=False, + escaped_from=None, + **kw + ): + + if self.positional: + if positional_names is not None: + positional_names.append(name) + else: + self.positiontup.append(name) + elif not escaped_from: + + if _BIND_TRANSLATE_RE.search(name): + # not quite the translate use case as we want to + # also get a quick boolean if we even found + # unusual characters in the name + new_name = _BIND_TRANSLATE_RE.sub( + lambda m: _BIND_TRANSLATE_CHARS[m.group(0)], + name, + ) + escaped_from = name + name = new_name + + if escaped_from: + if not self.escaped_bind_names: + self.escaped_bind_names = {} + self.escaped_bind_names[escaped_from] = name + if post_compile: + return "__[POSTCOMPILE_%s]" % name + else: + return self.bindtemplate % {"name": name} + + def visit_cte( + self, + cte, + asfrom=False, + ashint=False, + fromhints=None, + visiting_cte=None, + from_linter=None, + **kwargs + ): + self._init_cte_state() + + kwargs["visiting_cte"] = cte + + cte_name = cte.name + + if isinstance(cte_name, elements._truncated_label): + cte_name = self._truncated_identifier("alias", cte_name) + + is_new_cte = True + embedded_in_current_named_cte = False + + _reference_cte = cte._get_reference_cte() + + if _reference_cte in self.level_name_by_cte: + cte_level, _ = self.level_name_by_cte[_reference_cte] + assert _ == cte_name + else: + cte_level = len(self.stack) if cte.nesting else 1 + + cte_level_name = (cte_level, cte_name) + if cte_level_name in self.ctes_by_level_name: + existing_cte = self.ctes_by_level_name[cte_level_name] + embedded_in_current_named_cte = visiting_cte is existing_cte + + # we've generated a same-named CTE that we are enclosed in, + # or this is the same CTE. just return the name. + if cte is existing_cte._restates or cte is existing_cte: + is_new_cte = False + elif existing_cte is cte._restates: + # we've generated a same-named CTE that is + # enclosed in us - we take precedence, so + # discard the text for the "inner". + del self.ctes[existing_cte] + + existing_cte_reference_cte = existing_cte._get_reference_cte() + + # TODO: determine if these assertions are correct. they + # pass for current test cases + # assert existing_cte_reference_cte is _reference_cte + # assert existing_cte_reference_cte is existing_cte + + del self.level_name_by_cte[existing_cte_reference_cte] + else: + # if the two CTEs are deep-copy identical, consider them + # the same, **if** they are clones, that is, they came from + # the ORM or other visit method + if ( + cte._is_clone_of is not None + or existing_cte._is_clone_of is not None + ) and cte.compare(existing_cte): + is_new_cte = False + else: + raise exc.CompileError( + "Multiple, unrelated CTEs found with " + "the same name: %r" % cte_name + ) + + if not asfrom and not is_new_cte: + return None + + if cte._cte_alias is not None: + pre_alias_cte = cte._cte_alias + cte_pre_alias_name = cte._cte_alias.name + if isinstance(cte_pre_alias_name, elements._truncated_label): + cte_pre_alias_name = self._truncated_identifier( + "alias", cte_pre_alias_name + ) + else: + pre_alias_cte = cte + cte_pre_alias_name = None + + if is_new_cte: + self.ctes_by_level_name[cte_level_name] = cte + self.level_name_by_cte[_reference_cte] = cte_level_name + + if ( + "autocommit" in cte.element._execution_options + and "autocommit" not in self.execution_options + ): + self.execution_options = self.execution_options.union( + { + "autocommit": cte.element._execution_options[ + "autocommit" + ] + } + ) + + if pre_alias_cte not in self.ctes: + self.visit_cte(pre_alias_cte, **kwargs) + + if not cte_pre_alias_name and cte not in self.ctes: + if cte.recursive: + self.ctes_recursive = True + text = self.preparer.format_alias(cte, cte_name) + if cte.recursive: + if isinstance(cte.element, selectable.Select): + col_source = cte.element + elif isinstance(cte.element, selectable.CompoundSelect): + col_source = cte.element.selects[0] + else: + assert False, "cte should only be against SelectBase" + + # TODO: can we get at the .columns_plus_names collection + # that is already (or will be?) generated for the SELECT + # rather than calling twice? + recur_cols = [ + # TODO: proxy_name is not technically safe, + # see test_cte-> + # test_with_recursive_no_name_currently_buggy. not + # clear what should be done with such a case + fallback_label_name or proxy_name + for ( + _, + proxy_name, + fallback_label_name, + c, + repeated, + ) in (col_source._generate_columns_plus_names(True)) + if not repeated + ] + + text += "(%s)" % ( + ", ".join( + self.preparer.format_label_name( + ident, anon_map=self.anon_map + ) + for ident in recur_cols + ) + ) + + if self.positional: + kwargs["positional_names"] = self.cte_positional[cte] = [] + + assert kwargs.get("subquery", False) is False + + if not self.stack: + # toplevel, this is a stringify of the + # cte directly. just compile the inner + # the way alias() does. + return cte.element._compiler_dispatch( + self, asfrom=asfrom, **kwargs + ) + else: + prefixes = self._generate_prefixes( + cte, cte._prefixes, **kwargs + ) + inner = cte.element._compiler_dispatch( + self, asfrom=True, **kwargs + ) + + text += " AS %s\n(%s)" % (prefixes, inner) + + if cte._suffixes: + text += " " + self._generate_prefixes( + cte, cte._suffixes, **kwargs + ) + + self.ctes[cte] = text + + if asfrom: + if from_linter: + from_linter.froms[cte] = cte_name + + if not is_new_cte and embedded_in_current_named_cte: + return self.preparer.format_alias(cte, cte_name) + + if cte_pre_alias_name: + text = self.preparer.format_alias(cte, cte_pre_alias_name) + if self.preparer._requires_quotes(cte_name): + cte_name = self.preparer.quote(cte_name) + text += self.get_render_as_alias_suffix(cte_name) + return text + else: + return self.preparer.format_alias(cte, cte_name) + + def visit_table_valued_alias(self, element, **kw): + if element.joins_implicitly: + kw["from_linter"] = None + if element._is_lateral: + return self.visit_lateral(element, **kw) + else: + return self.visit_alias(element, **kw) + + def visit_table_valued_column(self, element, **kw): + return self.visit_column(element, **kw) + + def visit_alias( + self, + alias, + asfrom=False, + ashint=False, + iscrud=False, + fromhints=None, + subquery=False, + lateral=False, + enclosing_alias=None, + from_linter=None, + **kwargs + ): + + if lateral: + if "enclosing_lateral" not in kwargs: + # if lateral is set and enclosing_lateral is not + # present, we assume we are being called directly + # from visit_lateral() and we need to set enclosing_lateral. + assert alias._is_lateral + kwargs["enclosing_lateral"] = alias + + # for lateral objects, we track a second from_linter that is... + # lateral! to the level above us. + if ( + from_linter + and "lateral_from_linter" not in kwargs + and "enclosing_lateral" in kwargs + ): + kwargs["lateral_from_linter"] = from_linter + + if enclosing_alias is not None and enclosing_alias.element is alias: + inner = alias.element._compiler_dispatch( + self, + asfrom=asfrom, + ashint=ashint, + iscrud=iscrud, + fromhints=fromhints, + lateral=lateral, + enclosing_alias=alias, + **kwargs + ) + if subquery and (asfrom or lateral): + inner = "(%s)" % (inner,) + return inner + else: + enclosing_alias = kwargs["enclosing_alias"] = alias + + if asfrom or ashint: + if isinstance(alias.name, elements._truncated_label): + alias_name = self._truncated_identifier("alias", alias.name) + else: + alias_name = alias.name + + if ashint: + return self.preparer.format_alias(alias, alias_name) + elif asfrom: + if from_linter: + from_linter.froms[alias] = alias_name + + inner = alias.element._compiler_dispatch( + self, asfrom=True, lateral=lateral, **kwargs + ) + if subquery: + inner = "(%s)" % (inner,) + + ret = inner + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name) + ) + + if alias._supports_derived_columns and alias._render_derived: + ret += "(%s)" % ( + ", ".join( + "%s%s" + % ( + self.preparer.quote(col.name), + " %s" + % self.dialect.type_compiler.process( + col.type, **kwargs + ) + if alias._render_derived_w_types + else "", + ) + for col in alias.c + ) + ) + + if fromhints and alias in fromhints: + ret = self.format_from_hint_text( + ret, alias, fromhints[alias], iscrud + ) + + return ret + else: + # note we cancel the "subquery" flag here as well + return alias.element._compiler_dispatch( + self, lateral=lateral, **kwargs + ) + + def visit_subquery(self, subquery, **kw): + kw["subquery"] = True + return self.visit_alias(subquery, **kw) + + def visit_lateral(self, lateral_, **kw): + kw["lateral"] = True + return "LATERAL %s" % self.visit_alias(lateral_, **kw) + + def visit_tablesample(self, tablesample, asfrom=False, **kw): + text = "%s TABLESAMPLE %s" % ( + self.visit_alias(tablesample, asfrom=True, **kw), + tablesample._get_method()._compiler_dispatch(self, **kw), + ) + + if tablesample.seed is not None: + text += " REPEATABLE (%s)" % ( + tablesample.seed._compiler_dispatch(self, **kw) + ) + + return text + + def visit_values(self, element, asfrom=False, from_linter=None, **kw): + kw.setdefault("literal_binds", element.literal_binds) + v = "VALUES %s" % ", ".join( + self.process( + elements.Tuple( + types=element._column_types, *elem + ).self_group(), + **kw + ) + for chunk in element._data + for elem in chunk + ) + + if isinstance(element.name, elements._truncated_label): + name = self._truncated_identifier("values", element.name) + else: + name = element.name + + if element._is_lateral: + lateral = "LATERAL " + else: + lateral = "" + + if asfrom: + if from_linter: + from_linter.froms[element] = ( + name if name is not None else "(unnamed VALUES element)" + ) + + if name: + v = "%s(%s)%s (%s)" % ( + lateral, + v, + self.get_render_as_alias_suffix(self.preparer.quote(name)), + ( + ", ".join( + c._compiler_dispatch( + self, include_table=False, **kw + ) + for c in element.columns + ) + ), + ) + else: + v = "%s(%s)" % (lateral, v) + return v + + def get_render_as_alias_suffix(self, alias_name_text): + return " AS " + alias_name_text + + def _add_to_result_map(self, keyname, name, objects, type_): + if keyname is None or keyname == "*": + self._ordered_columns = False + self._textual_ordered_columns = True + if type_._is_tuple_type: + raise exc.CompileError( + "Most backends don't support SELECTing " + "from a tuple() object. If this is an ORM query, " + "consider using the Bundle object." + ) + self._result_columns.append((keyname, name, objects, type_)) + + def _label_returning_column(self, stmt, column, column_clause_args=None): + """Render a column with necessary labels inside of a RETURNING clause. + + This method is provided for individual dialects in place of calling + the _label_select_column method directly, so that the two use cases + of RETURNING vs. SELECT can be disambiguated going forward. + + .. versionadded:: 1.4.21 + + """ + return self._label_select_column( + None, + column, + True, + False, + {} if column_clause_args is None else column_clause_args, + ) + + def _label_select_column( + self, + select, + column, + populate_result_map, + asfrom, + column_clause_args, + name=None, + proxy_name=None, + fallback_label_name=None, + within_columns_clause=True, + column_is_repeated=False, + need_column_expressions=False, + ): + """produce labeled columns present in a select().""" + impl = column.type.dialect_impl(self.dialect) + + if impl._has_column_expression and ( + need_column_expressions or populate_result_map + ): + col_expr = impl.column_expression(column) + else: + col_expr = column + + if populate_result_map: + # pass an "add_to_result_map" callable into the compilation + # of embedded columns. this collects information about the + # column as it will be fetched in the result and is coordinated + # with cursor.description when the query is executed. + add_to_result_map = self._add_to_result_map + + # if the SELECT statement told us this column is a repeat, + # wrap the callable with one that prevents the addition of the + # targets + if column_is_repeated: + _add_to_result_map = add_to_result_map + + def add_to_result_map(keyname, name, objects, type_): + _add_to_result_map(keyname, name, (), type_) + + # if we redefined col_expr for type expressions, wrap the + # callable with one that adds the original column to the targets + elif col_expr is not column: + _add_to_result_map = add_to_result_map + + def add_to_result_map(keyname, name, objects, type_): + _add_to_result_map( + keyname, name, (column,) + objects, type_ + ) + + else: + add_to_result_map = None + + # this method is used by some of the dialects for RETURNING, + # which has different inputs. _label_returning_column was added + # as the better target for this now however for 1.4 we will keep + # _label_select_column directly compatible with this use case. + # these assertions right now set up the current expected inputs + assert within_columns_clause, ( + "_label_select_column is only relevant within " + "the columns clause of a SELECT or RETURNING" + ) + + if isinstance(column, elements.Label): + if col_expr is not column: + result_expr = _CompileLabel( + col_expr, column.name, alt_names=(column.element,) + ) + else: + result_expr = col_expr + + elif name: + # here, _columns_plus_names has determined there's an explicit + # label name we need to use. this is the default for + # tablenames_plus_columnnames as well as when columns are being + # deduplicated on name + + assert ( + proxy_name is not None + ), "proxy_name is required if 'name' is passed" + + result_expr = _CompileLabel( + col_expr, + name, + alt_names=( + proxy_name, + # this is a hack to allow legacy result column lookups + # to work as they did before; this goes away in 2.0. + # TODO: this only seems to be tested indirectly + # via test/orm/test_deprecations.py. should be a + # resultset test for this + column._tq_label, + ), + ) + else: + # determine here whether this column should be rendered in + # a labelled context or not, as we were given no required label + # name from the caller. Here we apply heuristics based on the kind + # of SQL expression involved. + + if col_expr is not column: + # type-specific expression wrapping the given column, + # so we render a label + render_with_label = True + elif isinstance(column, elements.ColumnClause): + # table-bound column, we render its name as a label if we are + # inside of a subquery only + render_with_label = ( + asfrom + and not column.is_literal + and column.table is not None + ) + elif isinstance(column, elements.TextClause): + render_with_label = False + elif isinstance(column, elements.UnaryExpression): + render_with_label = column.wraps_column_expression or asfrom + elif ( + # general class of expressions that don't have a SQL-column + # addressible name. includes scalar selects, bind parameters, + # SQL functions, others + not isinstance(column, elements.NamedColumn) + # deeper check that indicates there's no natural "name" to + # this element, which accommodates for custom SQL constructs + # that might have a ".name" attribute (but aren't SQL + # functions) but are not implementing this more recently added + # base class. in theory the "NamedColumn" check should be + # enough, however here we seek to maintain legacy behaviors + # as well. + and column._non_anon_label is None + ): + render_with_label = True + else: + render_with_label = False + + if render_with_label: + if not fallback_label_name: + # used by the RETURNING case right now. we generate it + # here as 3rd party dialects may be referring to + # _label_select_column method directly instead of the + # just-added _label_returning_column method + assert not column_is_repeated + fallback_label_name = column._anon_name_label + + fallback_label_name = ( + elements._truncated_label(fallback_label_name) + if not isinstance( + fallback_label_name, elements._truncated_label + ) + else fallback_label_name + ) + + result_expr = _CompileLabel( + col_expr, fallback_label_name, alt_names=(proxy_name,) + ) + else: + result_expr = col_expr + + column_clause_args.update( + within_columns_clause=within_columns_clause, + add_to_result_map=add_to_result_map, + ) + return result_expr._compiler_dispatch(self, **column_clause_args) + + def format_from_hint_text(self, sqltext, table, hint, iscrud): + hinttext = self.get_from_hint_text(table, hint) + if hinttext: + sqltext += " " + hinttext + return sqltext + + def get_select_hint_text(self, byfroms): + return None + + def get_from_hint_text(self, table, text): + return None + + def get_crud_hint_text(self, table, text): + return None + + def get_statement_hint_text(self, hint_texts): + return " ".join(hint_texts) + + _default_stack_entry = util.immutabledict( + [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] + ) + + def _display_froms_for_select( + self, select_stmt, asfrom, lateral=False, **kw + ): + # utility method to help external dialects + # get the correct from list for a select. + # specifically the oracle dialect needs this feature + # right now. + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + compile_state = select_stmt._compile_state_factory(select_stmt, self) + + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] + + if asfrom and not lateral: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms + ), + implicit_correlate_froms=(), + ) + else: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms, + ) + return froms + + translate_select_structure = None + """if not ``None``, should be a callable which accepts ``(select_stmt, + **kw)`` and returns a select object. this is used for structural changes + mostly to accommodate for LIMIT/OFFSET schemes + + """ + + def visit_select( + self, + select_stmt, + asfrom=False, + insert_into=False, + fromhints=None, + compound_index=None, + select_wraps_for=None, + lateral=False, + from_linter=None, + **kwargs + ): + assert select_wraps_for is None, ( + "SQLAlchemy 1.4 requires use of " + "the translate_select_structure hook for structural " + "translations of SELECT objects" + ) + + # initial setup of SELECT. the compile_state_factory may now + # be creating a totally different SELECT from the one that was + # passed in. for ORM use this will convert from an ORM-state + # SELECT to a regular "Core" SELECT. other composed operations + # such as computation of joins will be performed. + + kwargs["within_columns_clause"] = False + + compile_state = select_stmt._compile_state_factory( + select_stmt, self, **kwargs + ) + select_stmt = compile_state.statement + + toplevel = not self.stack + + if toplevel and not self.compile_state: + self.compile_state = compile_state + + is_embedded_select = compound_index is not None or insert_into + + # translate step for Oracle, SQL Server which often need to + # restructure the SELECT to allow for LIMIT/OFFSET and possibly + # other conditions + if self.translate_select_structure: + new_select_stmt = self.translate_select_structure( + select_stmt, asfrom=asfrom, **kwargs + ) + + # if SELECT was restructured, maintain a link to the originals + # and assemble a new compile state + if new_select_stmt is not select_stmt: + compile_state_wraps_for = compile_state + select_wraps_for = select_stmt + select_stmt = new_select_stmt + + compile_state = select_stmt._compile_state_factory( + select_stmt, self, **kwargs + ) + select_stmt = compile_state.statement + + entry = self._default_stack_entry if toplevel else self.stack[-1] + + populate_result_map = need_column_expressions = ( + toplevel + or entry.get("need_result_map_for_compound", False) + or entry.get("need_result_map_for_nested", False) + ) + + # indicates there is a CompoundSelect in play and we are not the + # first select + if compound_index: + populate_result_map = False + + # this was first proposed as part of #3372; however, it is not + # reached in current tests and could possibly be an assertion + # instead. + if not populate_result_map and "add_to_result_map" in kwargs: + del kwargs["add_to_result_map"] + + froms = self._setup_select_stack( + select_stmt, compile_state, entry, asfrom, lateral, compound_index + ) + + column_clause_args = kwargs.copy() + column_clause_args.update( + {"within_label_clause": False, "within_columns_clause": False} + ) + + text = "SELECT " # we're off to a good start ! + + if select_stmt._hints: + hint_text, byfrom = self._setup_select_hints(select_stmt) + if hint_text: + text += hint_text + " " + else: + byfrom = None + + if select_stmt._independent_ctes: + for cte in select_stmt._independent_ctes: + cte._compiler_dispatch(self, **kwargs) + + if select_stmt._prefixes: + text += self._generate_prefixes( + select_stmt, select_stmt._prefixes, **kwargs + ) + + text += self.get_select_precolumns(select_stmt, **kwargs) + # the actual list of columns to print in the SELECT column list. + inner_columns = [ + c + for c in [ + self._label_select_column( + select_stmt, + column, + populate_result_map, + asfrom, + column_clause_args, + name=name, + proxy_name=proxy_name, + fallback_label_name=fallback_label_name, + column_is_repeated=repeated, + need_column_expressions=need_column_expressions, + ) + for ( + name, + proxy_name, + fallback_label_name, + column, + repeated, + ) in compile_state.columns_plus_names + ] + if c is not None + ] + + if populate_result_map and select_wraps_for is not None: + # if this select was generated from translate_select, + # rewrite the targeted columns in the result map + + translate = dict( + zip( + [ + name + for ( + key, + proxy_name, + fallback_label_name, + name, + repeated, + ) in compile_state.columns_plus_names + ], + [ + name + for ( + key, + proxy_name, + fallback_label_name, + name, + repeated, + ) in compile_state_wraps_for.columns_plus_names + ], + ) + ) + + self._result_columns = [ + (key, name, tuple(translate.get(o, o) for o in obj), type_) + for key, name, obj, type_ in self._result_columns + ] + + text = self._compose_select_body( + text, + select_stmt, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, + ) + + if select_stmt._statement_hints: + per_dialect = [ + ht + for (dialect_name, ht) in select_stmt._statement_hints + if dialect_name in ("*", self.dialect.name) + ] + if per_dialect: + text += " " + self.get_statement_hint_text(per_dialect) + + if self.ctes: + # In compound query, CTEs are shared at the compound level + if not is_embedded_select: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause(nesting_level=nesting_level) + text + ) + + if select_stmt._suffixes: + text += " " + self._generate_prefixes( + select_stmt, select_stmt._suffixes, **kwargs + ) + + self.stack.pop(-1) + + return text + + def _setup_select_hints(self, select): + byfrom = dict( + [ + ( + from_, + hinttext + % {"name": from_._compiler_dispatch(self, ashint=True)}, + ) + for (from_, dialect), hinttext in select._hints.items() + if dialect in ("*", self.dialect.name) + ] + ) + hint_text = self.get_select_hint_text(byfrom) + return hint_text, byfrom + + def _setup_select_stack( + self, select, compile_state, entry, asfrom, lateral, compound_index + ): + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] + + if compound_index == 0: + entry["select_0"] = select + elif compound_index: + select_0 = entry["select_0"] + numcols = len(select_0._all_selected_columns) + + if len(compile_state.columns_plus_names) != numcols: + raise exc.CompileError( + "All selectables passed to " + "CompoundSelect must have identical numbers of " + "columns; select #%d has %d columns, select " + "#%d has %d" + % ( + 1, + numcols, + compound_index + 1, + len(select._all_selected_columns), + ) + ) + + if asfrom and not lateral: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms + ), + implicit_correlate_froms=(), + ) + else: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms, + ) + + new_correlate_froms = set(selectable._from_objects(*froms)) + all_correlate_froms = new_correlate_froms.union(correlate_froms) + + new_entry = { + "asfrom_froms": new_correlate_froms, + "correlate_froms": all_correlate_froms, + "selectable": select, + "compile_state": compile_state, + } + self.stack.append(new_entry) + + return froms + + def _compose_select_body( + self, + text, + select, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, + ): + text += ", ".join(inner_columns) + + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + + if froms: + text += " \nFROM " + + if select._hints: + text += ", ".join( + [ + f._compiler_dispatch( + self, + asfrom=True, + fromhints=byfrom, + from_linter=from_linter, + **kwargs + ) + for f in froms + ] + ) + else: + text += ", ".join( + [ + f._compiler_dispatch( + self, + asfrom=True, + from_linter=from_linter, + **kwargs + ) + for f in froms + ] + ) + else: + text += self.default_from() + + if select._where_criteria: + t = self._generate_delimited_and_list( + select._where_criteria, from_linter=from_linter, **kwargs + ) + if t: + text += " \nWHERE " + t + + if warn_linting: + from_linter.warn() + + if select._group_by_clauses: + text += self.group_by_clause(select, **kwargs) + + if select._having_criteria: + t = self._generate_delimited_and_list( + select._having_criteria, **kwargs + ) + if t: + text += " \nHAVING " + t + + if select._order_by_clauses: + text += self.order_by_clause(select, **kwargs) + + if select._has_row_limiting_clause: + text += self._row_limit_clause(select, **kwargs) + + if select._for_update_arg is not None: + text += self.for_update_clause(select, **kwargs) + + return text + + def _generate_prefixes(self, stmt, prefixes, **kw): + clause = " ".join( + prefix._compiler_dispatch(self, **kw) + for prefix, dialect_name in prefixes + if dialect_name is None or dialect_name == self.dialect.name + ) + if clause: + clause += " " + return clause + + def _render_cte_clause( + self, + nesting_level=None, + include_following_stack=False, + ): + """ + include_following_stack + Also render the nesting CTEs on the next stack. Useful for + SQL structures like UNION or INSERT that can wrap SELECT + statements containing nesting CTEs. + """ + if not self.ctes: + return "" + + if nesting_level and nesting_level > 1: + ctes = util.OrderedDict() + for cte in list(self.ctes.keys()): + cte_level, cte_name = self.level_name_by_cte[ + cte._get_reference_cte() + ] + is_rendered_level = cte_level == nesting_level or ( + include_following_stack and cte_level == nesting_level + 1 + ) + if not (cte.nesting and is_rendered_level): + continue + + ctes[cte] = self.ctes[cte] + + else: + ctes = self.ctes + + if not ctes: + return "" + + ctes_recursive = any([cte.recursive for cte in ctes]) + + if self.positional: + self.positiontup = ( + sum([self.cte_positional[cte] for cte in ctes], []) + + self.positiontup + ) + cte_text = self.get_cte_preamble(ctes_recursive) + " " + cte_text += ", \n".join([txt for txt in ctes.values()]) + cte_text += "\n " + + if nesting_level and nesting_level > 1: + for cte in list(ctes.keys()): + cte_level, cte_name = self.level_name_by_cte[ + cte._get_reference_cte() + ] + del self.ctes[cte] + del self.ctes_by_level_name[(cte_level, cte_name)] + del self.level_name_by_cte[cte._get_reference_cte()] + + return cte_text + + def get_cte_preamble(self, recursive): + if recursive: + return "WITH RECURSIVE" + else: + return "WITH" + + def get_select_precolumns(self, select, **kw): + """Called when building a ``SELECT`` statement, position is just + before column list. + + """ + if select._distinct_on: + util.warn_deprecated( + "DISTINCT ON is currently supported only by the PostgreSQL " + "dialect. Use of DISTINCT ON for other backends is currently " + "silently ignored, however this usage is deprecated, and will " + "raise CompileError in a future release for all backends " + "that do not support this syntax.", + version="1.4", + ) + return "DISTINCT " if select._distinct else "" + + def group_by_clause(self, select, **kw): + """allow dialects to customize how GROUP BY is rendered.""" + + group_by = self._generate_delimited_list( + select._group_by_clauses, OPERATORS[operators.comma_op], **kw + ) + if group_by: + return " GROUP BY " + group_by + else: + return "" + + def order_by_clause(self, select, **kw): + """allow dialects to customize how ORDER BY is rendered.""" + + order_by = self._generate_delimited_list( + select._order_by_clauses, OPERATORS[operators.comma_op], **kw + ) + + if order_by: + return " ORDER BY " + order_by + else: + return "" + + def for_update_clause(self, select, **kw): + return " FOR UPDATE" + + def returning_clause(self, stmt, returning_cols): + raise exc.CompileError( + "RETURNING is not supported by this " + "dialect's statement compiler." + ) + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + text += "\n LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: + text += "\n LIMIT -1" + text += " OFFSET " + self.process(select._offset_clause, **kw) + return text + + def fetch_clause(self, select, **kw): + text = "" + if select._offset_clause is not None: + text += "\n OFFSET %s ROWS" % self.process( + select._offset_clause, **kw + ) + if select._fetch_clause is not None: + text += "\n FETCH FIRST %s%s ROWS %s" % ( + self.process(select._fetch_clause, **kw), + " PERCENT" if select._fetch_clause_options["percent"] else "", + "WITH TIES" + if select._fetch_clause_options["with_ties"] + else "ONLY", + ) + return text + + def visit_table( + self, + table, + asfrom=False, + iscrud=False, + ashint=False, + fromhints=None, + use_schema=True, + from_linter=None, + **kwargs + ): + if from_linter: + from_linter.froms[table] = table.fullname + + if asfrom or ashint: + effective_schema = self.preparer.schema_for_object(table) + + if use_schema and effective_schema: + ret = ( + self.preparer.quote_schema(effective_schema) + + "." + + self.preparer.quote(table.name) + ) + else: + ret = self.preparer.quote(table.name) + if fromhints and table in fromhints: + ret = self.format_from_hint_text( + ret, table, fromhints[table], iscrud + ) + return ret + else: + return "" + + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.update( + itertools.product( + join.left._from_objects, join.right._from_objects + ) + ) + + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + else: + join_type = " JOIN " + return ( + join.left._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + + join_type + + join.right._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + + " ON " + # TODO: likely need asfrom=True here? + + join.onclause._compiler_dispatch( + self, from_linter=from_linter, **kwargs + ) + ) + + def _setup_crud_hints(self, stmt, table_text): + dialect_hints = dict( + [ + (table, hint_text) + for (table, dialect), hint_text in stmt._hints.items() + if dialect in ("*", self.dialect.name) + ] + ) + if stmt.table in dialect_hints: + table_text = self.format_from_hint_text( + table_text, stmt.table, dialect_hints[stmt.table], True + ) + return dialect_hints, table_text + + def visit_insert(self, insert_stmt, **kw): + + compile_state = insert_stmt._compile_state_factory( + insert_stmt, self, **kw + ) + insert_stmt = compile_state.statement + + toplevel = not self.stack + + if toplevel: + self.isinsert = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state + + self.stack.append( + { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": insert_stmt, + } + ) + + crud_params = crud._get_crud_params( + self, insert_stmt, compile_state, **kw + ) + + if ( + not crud_params + and not self.dialect.supports_default_values + and not self.dialect.supports_default_metavalue + and not self.dialect.supports_empty_insert + ): + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support empty " + "inserts." % self.dialect.name + ) + + if compile_state._has_multi_parameters: + if not self.dialect.supports_multivalues_insert: + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support " + "in-place multirow inserts." % self.dialect.name + ) + crud_params_single = crud_params[0] + else: + crud_params_single = crud_params + + preparer = self.preparer + supports_default_values = self.dialect.supports_default_values + + text = "INSERT " + + if insert_stmt._prefixes: + text += self._generate_prefixes( + insert_stmt, insert_stmt._prefixes, **kw + ) + + text += "INTO " + table_text = preparer.format_table(insert_stmt.table) + + if insert_stmt._hints: + _, table_text = self._setup_crud_hints(insert_stmt, table_text) + + if insert_stmt._independent_ctes: + for cte in insert_stmt._independent_ctes: + cte._compiler_dispatch(self, **kw) + + text += table_text + + if crud_params_single or not supports_default_values: + text += " (%s)" % ", ".join( + [expr for c, expr, value in crud_params_single] + ) + + if self.returning or insert_stmt._returning: + returning_clause = self.returning_clause( + insert_stmt, self.returning or insert_stmt._returning + ) + + if self.returning_precedes_values: + text += " " + returning_clause + else: + returning_clause = None + + if insert_stmt.select is not None: + # placed here by crud.py + select_text = self.process( + self.stack[-1]["insert_from_select"], insert_into=True, **kw + ) + + if self.ctes and self.dialect.cte_follows_insert: + nesting_level = len(self.stack) if not toplevel else None + text += " %s%s" % ( + self._render_cte_clause( + nesting_level=nesting_level, + include_following_stack=True, + ), + select_text, + ) + else: + text += " %s" % select_text + elif not crud_params and supports_default_values: + text += " DEFAULT VALUES" + elif compile_state._has_multi_parameters: + text += " VALUES %s" % ( + ", ".join( + "(%s)" + % (", ".join(value for c, expr, value in crud_param_set)) + for crud_param_set in crud_params + ) + ) + else: + insert_single_values_expr = ", ".join( + [value for c, expr, value in crud_params] + ) + text += " VALUES (%s)" % insert_single_values_expr + if toplevel: + self.insert_single_values_expr = insert_single_values_expr + + if insert_stmt._post_values_clause is not None: + post_values_clause = self.process( + insert_stmt._post_values_clause, **kw + ) + if post_values_clause: + text += " " + post_values_clause + + if returning_clause and not self.returning_precedes_values: + text += " " + returning_clause + + if self.ctes and not self.dialect.cte_follows_insert: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, include_following_stack=True + ) + + text + ) + + self.stack.pop(-1) + + return text + + def update_limit_clause(self, update_stmt): + """Provide a hook for MySQL to add LIMIT to the UPDATE""" + return None + + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + """Provide a hook to override the initial table clause + in an UPDATE statement. + + MySQL overrides this. + + """ + kw["asfrom"] = True + return from_table._compiler_dispatch(self, iscrud=True, **kw) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + """Provide a hook to override the generation of an + UPDATE..FROM clause. + + MySQL and MSSQL override this. + + """ + raise NotImplementedError( + "This backend does not support multiple-table " + "criteria within UPDATE" + ) + + def visit_update(self, update_stmt, **kw): + compile_state = update_stmt._compile_state_factory( + update_stmt, self, **kw + ) + update_stmt = compile_state.statement + + toplevel = not self.stack + if toplevel: + self.isupdate = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state + + extra_froms = compile_state._extra_froms + is_multitable = bool(extra_froms) + + if is_multitable: + # main table might be a JOIN + main_froms = set(selectable._from_objects(update_stmt.table)) + render_extra_froms = [ + f for f in extra_froms if f not in main_froms + ] + correlate_froms = main_froms.union(extra_froms) + else: + render_extra_froms = [] + correlate_froms = {update_stmt.table} + + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": update_stmt, + } + ) + + text = "UPDATE " + + if update_stmt._prefixes: + text += self._generate_prefixes( + update_stmt, update_stmt._prefixes, **kw + ) + + table_text = self.update_tables_clause( + update_stmt, update_stmt.table, render_extra_froms, **kw + ) + crud_params = crud._get_crud_params( + self, update_stmt, compile_state, **kw + ) + + if update_stmt._hints: + dialect_hints, table_text = self._setup_crud_hints( + update_stmt, table_text + ) + else: + dialect_hints = None + + if update_stmt._independent_ctes: + for cte in update_stmt._independent_ctes: + cte._compiler_dispatch(self, **kw) + + text += table_text + + text += " SET " + text += ", ".join(expr + "=" + value for c, expr, value in crud_params) + + if self.returning or update_stmt._returning: + if self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, self.returning or update_stmt._returning + ) + + if extra_froms: + extra_from_text = self.update_from_clause( + update_stmt, + update_stmt.table, + render_extra_froms, + dialect_hints, + **kw + ) + if extra_from_text: + text += " " + extra_from_text + + if update_stmt._where_criteria: + t = self._generate_delimited_and_list( + update_stmt._where_criteria, **kw + ) + if t: + text += " WHERE " + t + + limit_clause = self.update_limit_clause(update_stmt) + if limit_clause: + text += " " + limit_clause + + if ( + self.returning or update_stmt._returning + ) and not self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, self.returning or update_stmt._returning + ) + + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + self.stack.pop(-1) + + return text + + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + """Provide a hook to override the generation of an + DELETE..FROM clause. + + This can be used to implement DELETE..USING for example. + + MySQL and MSSQL override this. + + """ + raise NotImplementedError( + "This backend does not support multiple-table " + "criteria within DELETE" + ) + + def delete_table_clause(self, delete_stmt, from_table, extra_froms): + return from_table._compiler_dispatch(self, asfrom=True, iscrud=True) + + def visit_delete(self, delete_stmt, **kw): + compile_state = delete_stmt._compile_state_factory( + delete_stmt, self, **kw + ) + delete_stmt = compile_state.statement + + toplevel = not self.stack + if toplevel: + self.isdelete = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state + + extra_froms = compile_state._extra_froms + + correlate_froms = {delete_stmt.table}.union(extra_froms) + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": delete_stmt, + } + ) + + text = "DELETE " + + if delete_stmt._prefixes: + text += self._generate_prefixes( + delete_stmt, delete_stmt._prefixes, **kw + ) + + text += "FROM " + table_text = self.delete_table_clause( + delete_stmt, delete_stmt.table, extra_froms + ) + + if delete_stmt._hints: + dialect_hints, table_text = self._setup_crud_hints( + delete_stmt, table_text + ) + else: + dialect_hints = None + + if delete_stmt._independent_ctes: + for cte in delete_stmt._independent_ctes: + cte._compiler_dispatch(self, **kw) + + text += table_text + + if delete_stmt._returning: + if self.returning_precedes_values: + text += " " + self.returning_clause( + delete_stmt, delete_stmt._returning + ) + + if extra_froms: + extra_from_text = self.delete_extra_from_clause( + delete_stmt, + delete_stmt.table, + extra_froms, + dialect_hints, + **kw + ) + if extra_from_text: + text += " " + extra_from_text + + if delete_stmt._where_criteria: + t = self._generate_delimited_and_list( + delete_stmt._where_criteria, **kw + ) + if t: + text += " WHERE " + t + + if delete_stmt._returning and not self.returning_precedes_values: + text += " " + self.returning_clause( + delete_stmt, delete_stmt._returning + ) + + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + self.stack.pop(-1) + + return text + + def visit_savepoint(self, savepoint_stmt): + return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + def visit_release_savepoint(self, savepoint_stmt): + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + +class StrSQLCompiler(SQLCompiler): + """A :class:`.SQLCompiler` subclass which allows a small selection + of non-standard SQL features to render into a string value. + + The :class:`.StrSQLCompiler` is invoked whenever a Core expression + element is directly stringified without calling upon the + :meth:`_expression.ClauseElement.compile` method. + It can render a limited set + of non-standard SQL constructs to assist in basic stringification, + however for more substantial custom or dialect-specific SQL constructs, + it will be necessary to make use of + :meth:`_expression.ClauseElement.compile` + directly. + + .. seealso:: + + :ref:`faq_sql_expression_string` + + """ + + def _fallback_column_name(self, column): + return "" + + @util.preload_module("sqlalchemy.engine.url") + def visit_unsupported_compilation(self, element, err, **kw): + if element.stringify_dialect != "default": + url = util.preloaded.engine_url + dialect = url.URL.create(element.stringify_dialect).get_dialect()() + + compiler = dialect.statement_compiler(dialect, None) + if not isinstance(compiler, StrSQLCompiler): + return compiler.process(element) + + return super(StrSQLCompiler, self).visit_unsupported_compilation( + element, err + ) + + def visit_getitem_binary(self, binary, operator, **kw): + return "%s[%s]" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + return self.visit_getitem_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return self.visit_getitem_binary(binary, operator, **kw) + + def visit_sequence(self, seq, **kw): + return "" % self.preparer.format_sequence(seq) + + def returning_clause(self, stmt, returning_cols): + columns = [ + self._label_select_column(None, c, True, False, {}) + for c in base._select_iterables(returning_cols) + ] + + return "RETURNING " + ", ".join(columns) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + kw["asfrom"] = True + return "FROM " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + kw["asfrom"] = True + return ", " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def visit_empty_set_expr(self, type_): + return "SELECT 1 WHERE 1!=1" + + def get_from_hint_text(self, table, text): + return "[%s]" % text + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " ", **kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " ", **kw) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + replacement = binary.modifiers["replacement"] + return "(%s, %s, %s)" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + replacement._compiler_dispatch(self, **kw), + ) + + +class DDLCompiler(Compiled): + @util.memoized_property + def sql_compiler(self): + return self.dialect.statement_compiler( + self.dialect, None, schema_translate_map=self.schema_translate_map + ) + + @util.memoized_property + def type_compiler(self): + return self.dialect.type_compiler + + def construct_params( + self, params=None, extracted_parameters=None, escape_names=True + ): + return None + + def visit_ddl(self, ddl, **kwargs): + # table events can substitute table and schema name + context = ddl.context + if isinstance(ddl.target, schema.Table): + context = context.copy() + + preparer = self.preparer + path = preparer.format_table_seq(ddl.target) + if len(path) == 1: + table, sch = path[0], "" + else: + table, sch = path[-1], path[0] + + context.setdefault("table", table) + context.setdefault("schema", sch) + context.setdefault("fullname", preparer.format_table(ddl.target)) + + return self.sql_compiler.post_process_text(ddl.statement % context) + + def visit_create_schema(self, create, **kw): + schema = self.preparer.format_schema(create.element) + return "CREATE SCHEMA " + schema + + def visit_drop_schema(self, drop, **kw): + schema = self.preparer.format_schema(drop.element) + text = "DROP SCHEMA " + schema + if drop.cascade: + text += " CASCADE" + return text + + def visit_create_table(self, create, **kw): + table = create.element + preparer = self.preparer + + text = "\nCREATE " + if table._prefixes: + text += " ".join(table._prefixes) + " " + + text += "TABLE " + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += preparer.format_table(table) + " " + + create_table_suffix = self.create_table_suffix(table) + if create_table_suffix: + text += create_table_suffix + " " + + text += "(" + + separator = "\n" + + # if only one primary key, specify it along with the column + first_pk = False + for create_column in create.columns: + column = create_column.element + try: + processed = self.process( + create_column, first_pk=column.primary_key and not first_pk + ) + if processed is not None: + text += separator + separator = ", \n" + text += "\t" + processed + if column.primary_key: + first_pk = True + except exc.CompileError as ce: + util.raise_( + exc.CompileError( + util.u("(in table '%s', column '%s'): %s") + % (table.description, column.name, ce.args[0]) + ), + from_=ce, + ) + + const = self.create_table_constraints( + table, + _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa + ) + if const: + text += separator + "\t" + const + + text += "\n)%s\n\n" % self.post_create_table(table) + return text + + def visit_create_column(self, create, first_pk=False, **kw): + column = create.element + + if column.system: + return None + + text = self.get_column_specification(column, first_pk=first_pk) + const = " ".join( + self.process(constraint) for constraint in column.constraints + ) + if const: + text += " " + const + + return text + + def create_table_constraints( + self, table, _include_foreign_key_constraints=None, **kw + ): + + # On some DB order is significant: visit PK first, then the + # other constraints (engine.ReflectionTest.testbasic failed on FB2) + constraints = [] + if table.primary_key: + constraints.append(table.primary_key) + + all_fkcs = table.foreign_key_constraints + if _include_foreign_key_constraints is not None: + omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints) + else: + omit_fkcs = set() + + constraints.extend( + [ + c + for c in table._sorted_constraints + if c is not table.primary_key and c not in omit_fkcs + ] + ) + + return ", \n\t".join( + p + for p in ( + self.process(constraint) + for constraint in constraints + if ( + constraint._create_rule is None + or constraint._create_rule(self) + ) + and ( + not self.dialect.supports_alter + or not getattr(constraint, "use_alter", False) + ) + ) + if p is not None + ) + + def visit_drop_table(self, drop, **kw): + text = "\nDROP TABLE " + if drop.if_exists: + text += "IF EXISTS " + return text + self.preparer.format_table(drop.element) + + def visit_drop_view(self, drop, **kw): + return "\nDROP VIEW " + self.preparer.format_table(drop.element) + + def _verify_index_table(self, index): + if index.table is None: + raise exc.CompileError( + "Index '%s' is not associated " "with any table." % index.name + ) + + def visit_create_index( + self, create, include_schema=False, include_table_schema=True, **kw + ): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + if index.name is None: + raise exc.CompileError( + "CREATE INDEX requires that the index have a name" + ) + + text += "INDEX " + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += "%s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table( + index.table, use_schema=include_table_schema + ), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) + return text + + def visit_drop_index(self, drop, **kw): + index = drop.element + + if index.name is None: + raise exc.CompileError( + "DROP INDEX requires that the index have a name" + ) + text = "\nDROP INDEX " + if drop.if_exists: + text += "IF EXISTS " + + return text + self._prepared_index_name(index, include_schema=True) + + def _prepared_index_name(self, index, include_schema=False): + if index.table is not None: + effective_schema = self.preparer.schema_for_object(index.table) + else: + effective_schema = None + if include_schema and effective_schema: + schema_name = self.preparer.quote_schema(effective_schema) + else: + schema_name = None + + index_name = self.preparer.format_index(index) + + if schema_name: + index_name = schema_name + "." + index_name + return index_name + + def visit_add_constraint(self, create, **kw): + return "ALTER TABLE %s ADD %s" % ( + self.preparer.format_table(create.element.table), + self.process(create.element), + ) + + def visit_set_table_comment(self, create, **kw): + return "COMMENT ON TABLE %s IS %s" % ( + self.preparer.format_table(create.element), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_table_comment(self, drop, **kw): + return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( + drop.element + ) + + def visit_set_column_comment(self, create, **kw): + return "COMMENT ON COLUMN %s IS %s" % ( + self.preparer.format_column( + create.element, use_table=True, use_schema=True + ), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_column_comment(self, drop, **kw): + return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column( + drop.element, use_table=True + ) + + def get_identity_options(self, identity_options): + text = [] + if identity_options.increment is not None: + text.append("INCREMENT BY %d" % identity_options.increment) + if identity_options.start is not None: + text.append("START WITH %d" % identity_options.start) + if identity_options.minvalue is not None: + text.append("MINVALUE %d" % identity_options.minvalue) + if identity_options.maxvalue is not None: + text.append("MAXVALUE %d" % identity_options.maxvalue) + if identity_options.nominvalue is not None: + text.append("NO MINVALUE") + if identity_options.nomaxvalue is not None: + text.append("NO MAXVALUE") + if identity_options.cache is not None: + text.append("CACHE %d" % identity_options.cache) + if identity_options.order is not None: + text.append("ORDER" if identity_options.order else "NO ORDER") + if identity_options.cycle is not None: + text.append("CYCLE" if identity_options.cycle else "NO CYCLE") + return " ".join(text) + + def visit_create_sequence(self, create, prefix=None, **kw): + text = "CREATE SEQUENCE %s" % self.preparer.format_sequence( + create.element + ) + if prefix: + text += prefix + if create.element.start is None: + create.element.start = self.dialect.default_sequence_base + options = self.get_identity_options(create.element) + if options: + text += " " + options + return text + + def visit_drop_sequence(self, drop, **kw): + return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element) + + def visit_drop_constraint(self, drop, **kw): + constraint = drop.element + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + else: + formatted_name = None + + if formatted_name is None: + raise exc.CompileError( + "Can't emit DROP CONSTRAINT for constraint %r; " + "it has no name" % drop.element + ) + return "ALTER TABLE %s DROP CONSTRAINT %s%s" % ( + self.preparer.format_table(drop.element.table), + formatted_name, + drop.cascade and " CASCADE" or "", + ) + + def get_column_specification(self, column, **kwargs): + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler.process( + column.type, type_expression=column + ) + ) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if column.computed is not None: + colspec += " " + self.process(column.computed) + + if ( + column.identity is not None + and self.dialect.supports_identity_columns + ): + colspec += " " + self.process(column.identity) + + if not column.nullable and ( + not column.identity or not self.dialect.supports_identity_columns + ): + colspec += " NOT NULL" + return colspec + + def create_table_suffix(self, table): + return "" + + def post_create_table(self, table): + return "" + + def get_column_default_string(self, column): + if isinstance(column.server_default, schema.DefaultClause): + if isinstance(column.server_default.arg, util.string_types): + return self.sql_compiler.render_literal_value( + column.server_default.arg, sqltypes.STRINGTYPE + ) + else: + return self.sql_compiler.process( + column.server_default.arg, literal_binds=True + ) + else: + return None + + def visit_table_or_column_check_constraint(self, constraint, **kw): + if constraint.is_column_level: + return self.visit_column_check_constraint(constraint) + else: + return self.visit_check_constraint(constraint) + + def visit_check_constraint(self, constraint, **kw): + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_column_check_constraint(self, constraint, **kw): + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_primary_key_constraint(self, constraint, **kw): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "PRIMARY KEY " + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) + for c in ( + constraint.columns_autoinc_first + if constraint._implicit_generated + else constraint.columns + ) + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_foreign_key_constraint(self, constraint, **kw): + preparer = self.preparer + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + remote_table = list(constraint.elements)[0].column.table + text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + ", ".join( + preparer.quote(f.parent.name) for f in constraint.elements + ), + self.define_constraint_remote_table( + constraint, remote_table, preparer + ), + ", ".join( + preparer.quote(f.column.name) for f in constraint.elements + ), + ) + text += self.define_constraint_match(constraint) + text += self.define_constraint_cascades(constraint) + text += self.define_constraint_deferrability(constraint) + return text + + def define_constraint_remote_table(self, constraint, table, preparer): + """Format the remote table clause of a CREATE CONSTRAINT clause.""" + + return preparer.format_table(table) + + def visit_unique_constraint(self, constraint, **kw): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "UNIQUE (%s)" % ( + ", ".join(self.preparer.quote(c.name) for c in constraint) + ) + text += self.define_constraint_deferrability(constraint) + return text + + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % self.preparer.validate_sql_phrase( + constraint.ondelete, FK_ON_DELETE + ) + if constraint.onupdate is not None: + text += " ON UPDATE %s" % self.preparer.validate_sql_phrase( + constraint.onupdate, FK_ON_UPDATE + ) + return text + + def define_constraint_deferrability(self, constraint): + text = "" + if constraint.deferrable is not None: + if constraint.deferrable: + text += " DEFERRABLE" + else: + text += " NOT DEFERRABLE" + if constraint.initially is not None: + text += " INITIALLY %s" % self.preparer.validate_sql_phrase( + constraint.initially, FK_INITIALLY + ) + return text + + def define_constraint_match(self, constraint): + text = "" + if constraint.match is not None: + text += " MATCH %s" % constraint.match + return text + + def visit_computed_column(self, generated, **kw): + text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + if generated.persisted is True: + text += " STORED" + elif generated.persisted is False: + text += " VIRTUAL" + return text + + def visit_identity_column(self, identity, **kw): + text = "GENERATED %s AS IDENTITY" % ( + "ALWAYS" if identity.always else "BY DEFAULT", + ) + options = self.get_identity_options(identity) + if options: + text += " (%s)" % options + return text + + +class GenericTypeCompiler(TypeCompiler): + def visit_FLOAT(self, type_, **kw): + return "FLOAT" + + def visit_REAL(self, type_, **kw): + return "REAL" + + def visit_NUMERIC(self, type_, **kw): + if type_.precision is None: + return "NUMERIC" + elif type_.scale is None: + return "NUMERIC(%(precision)s)" % {"precision": type_.precision} + else: + return "NUMERIC(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } + + def visit_DECIMAL(self, type_, **kw): + if type_.precision is None: + return "DECIMAL" + elif type_.scale is None: + return "DECIMAL(%(precision)s)" % {"precision": type_.precision} + else: + return "DECIMAL(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } + + def visit_INTEGER(self, type_, **kw): + return "INTEGER" + + def visit_SMALLINT(self, type_, **kw): + return "SMALLINT" + + def visit_BIGINT(self, type_, **kw): + return "BIGINT" + + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP" + + def visit_DATETIME(self, type_, **kw): + return "DATETIME" + + def visit_DATE(self, type_, **kw): + return "DATE" + + def visit_TIME(self, type_, **kw): + return "TIME" + + def visit_CLOB(self, type_, **kw): + return "CLOB" + + def visit_NCLOB(self, type_, **kw): + return "NCLOB" + + def _render_string_type(self, type_, name): + + text = name + if type_.length: + text += "(%d)" % type_.length + if type_.collation: + text += ' COLLATE "%s"' % type_.collation + return text + + def visit_CHAR(self, type_, **kw): + return self._render_string_type(type_, "CHAR") + + def visit_NCHAR(self, type_, **kw): + return self._render_string_type(type_, "NCHAR") + + def visit_VARCHAR(self, type_, **kw): + return self._render_string_type(type_, "VARCHAR") + + def visit_NVARCHAR(self, type_, **kw): + return self._render_string_type(type_, "NVARCHAR") + + def visit_TEXT(self, type_, **kw): + return self._render_string_type(type_, "TEXT") + + def visit_BLOB(self, type_, **kw): + return "BLOB" + + def visit_BINARY(self, type_, **kw): + return "BINARY" + (type_.length and "(%d)" % type_.length or "") + + def visit_VARBINARY(self, type_, **kw): + return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") + + def visit_BOOLEAN(self, type_, **kw): + return "BOOLEAN" + + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) + + def visit_boolean(self, type_, **kw): + return self.visit_BOOLEAN(type_, **kw) + + def visit_time(self, type_, **kw): + return self.visit_TIME(type_, **kw) + + def visit_datetime(self, type_, **kw): + return self.visit_DATETIME(type_, **kw) + + def visit_date(self, type_, **kw): + return self.visit_DATE(type_, **kw) + + def visit_big_integer(self, type_, **kw): + return self.visit_BIGINT(type_, **kw) + + def visit_small_integer(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) + + def visit_integer(self, type_, **kw): + return self.visit_INTEGER(type_, **kw) + + def visit_real(self, type_, **kw): + return self.visit_REAL(type_, **kw) + + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) + + def visit_numeric(self, type_, **kw): + return self.visit_NUMERIC(type_, **kw) + + def visit_string(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_unicode(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) + + def visit_unicode_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) + + def visit_enum(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_null(self, type_, **kw): + raise exc.CompileError( + "Can't generate DDL for %r; " + "did you forget to specify a " + "type on this Column?" % type_ + ) + + def visit_type_decorator(self, type_, **kw): + return self.process(type_.type_engine(self.dialect), **kw) + + def visit_user_defined(self, type_, **kw): + return type_.get_col_spec(**kw) + + +class StrSQLTypeCompiler(GenericTypeCompiler): + def process(self, type_, **kw): + try: + _compiler_dispatch = type_._compiler_dispatch + except AttributeError: + return self._visit_unknown(type_, **kw) + else: + return _compiler_dispatch(self, **kw) + + def __getattr__(self, key): + if key.startswith("visit_"): + return self._visit_unknown + else: + raise AttributeError(key) + + def _visit_unknown(self, type_, **kw): + if type_.__class__.__name__ == type_.__class__.__name__.upper(): + return type_.__class__.__name__ + else: + return repr(type_) + + def visit_null(self, type_, **kw): + return "NULL" + + def visit_user_defined(self, type_, **kw): + try: + get_col_spec = type_.get_col_spec + except AttributeError: + return repr(type_) + else: + return get_col_spec(**kw) + + +class IdentifierPreparer(object): + + """Handle quoting and case-folding of identifiers based on options.""" + + reserved_words = RESERVED_WORDS + + legal_characters = LEGAL_CHARACTERS + + illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + + schema_for_object = operator.attrgetter("schema") + """Return the .schema attribute for an object. + + For the default IdentifierPreparer, the schema for an object is always + the value of the ".schema" attribute. if the preparer is replaced + with one that has a non-empty schema_translate_map, the value of the + ".schema" attribute is rendered a symbol that will be converted to a + real schema name from the mapping post-compile. + + """ + + def __init__( + self, + dialect, + initial_quote='"', + final_quote=None, + escape_quote='"', + quote_case_sensitive_collations=True, + omit_schema=False, + ): + """Construct a new ``IdentifierPreparer`` object. + + initial_quote + Character that begins a delimited identifier. + + final_quote + Character that ends a delimited identifier. Defaults to + `initial_quote`. + + omit_schema + Prevent prepending schema name. Useful for databases that do + not support schemae. + """ + + self.dialect = dialect + self.initial_quote = initial_quote + self.final_quote = final_quote or self.initial_quote + self.escape_quote = escape_quote + self.escape_to_quote = self.escape_quote * 2 + self.omit_schema = omit_schema + self.quote_case_sensitive_collations = quote_case_sensitive_collations + self._strings = {} + self._double_percents = self.dialect.paramstyle in ( + "format", + "pyformat", + ) + + def _with_schema_translate(self, schema_translate_map): + prep = self.__class__.__new__(self.__class__) + prep.__dict__.update(self.__dict__) + + def symbol_getter(obj): + name = obj.schema + if name in schema_translate_map and obj._use_schema_map: + if name is not None and ("[" in name or "]" in name): + raise exc.CompileError( + "Square bracket characters ([]) not supported " + "in schema translate name '%s'" % name + ) + return quoted_name( + "__[SCHEMA_%s]" % (name or "_none"), quote=False + ) + else: + return obj.schema + + prep.schema_for_object = symbol_getter + return prep + + def _render_schema_translates(self, statement, schema_translate_map): + d = schema_translate_map + if None in d: + d["_none"] = d[None] + + def replace(m): + name = m.group(2) + effective_schema = d[name] + if not effective_schema: + effective_schema = self.dialect.default_schema_name + if not effective_schema: + # TODO: no coverage here + raise exc.CompileError( + "Dialect has no default schema name; can't " + "use None as dynamic schema target." + ) + return self.quote_schema(effective_schema) + + return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement) + + def _escape_identifier(self, value): + """Escape an identifier. + + Subclasses should override this to provide database-dependent + escaping behavior. + """ + + value = value.replace(self.escape_quote, self.escape_to_quote) + if self._double_percents: + value = value.replace("%", "%%") + return value + + def _unescape_identifier(self, value): + """Canonicalize an escaped identifier. + + Subclasses should override this to provide database-dependent + unescaping behavior that reverses _escape_identifier. + """ + + return value.replace(self.escape_to_quote, self.escape_quote) + + def validate_sql_phrase(self, element, reg): + """keyword sequence filter. + + a filter for elements that are intended to represent keyword sequences, + such as "INITIALLY", "INITIALLY DEFERRED", etc. no special characters + should be present. + + .. versionadded:: 1.3 + + """ + + if element is not None and not reg.match(element): + raise exc.CompileError( + "Unexpected SQL phrase: %r (matching against %r)" + % (element, reg.pattern) + ) + return element + + def quote_identifier(self, value): + """Quote an identifier. + + Subclasses should override this to provide database-dependent + quoting behavior. + """ + + return ( + self.initial_quote + + self._escape_identifier(value) + + self.final_quote + ) + + def _requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(util.text_type(value)) + or (lc_value != value) + ) + + def _requires_quotes_illegal_chars(self, value): + """Return True if the given identifier requires quoting, but + not taking case convention into account.""" + return not self.legal_characters.match(util.text_type(value)) + + def quote_schema(self, schema, force=None): + """Conditionally quote a schema name. + + + The name is quoted if it is a reserved word, contains quote-necessary + characters, or is an instance of :class:`.quoted_name` which includes + ``quote`` set to ``True``. + + Subclasses can override this to provide database-dependent + quoting behavior for schema names. + + :param schema: string schema name + :param force: unused + + .. deprecated:: 0.9 + + The :paramref:`.IdentifierPreparer.quote_schema.force` + parameter is deprecated and will be removed in a future + release. This flag has no effect on the behavior of the + :meth:`.IdentifierPreparer.quote` method; please refer to + :class:`.quoted_name`. + + """ + if force is not None: + # not using the util.deprecated_params() decorator in this + # case because of the additional function call overhead on this + # very performance-critical spot. + util.warn_deprecated( + "The IdentifierPreparer.quote_schema.force parameter is " + "deprecated and will be removed in a future release. This " + "flag has no effect on the behavior of the " + "IdentifierPreparer.quote method; please refer to " + "quoted_name().", + # deprecated 0.9. warning from 1.3 + version="0.9", + ) + + return self.quote(schema) + + def quote(self, ident, force=None): + """Conditionally quote an identifier. + + The identifier is quoted if it is a reserved word, contains + quote-necessary characters, or is an instance of + :class:`.quoted_name` which includes ``quote`` set to ``True``. + + Subclasses can override this to provide database-dependent + quoting behavior for identifier names. + + :param ident: string identifier + :param force: unused + + .. deprecated:: 0.9 + + The :paramref:`.IdentifierPreparer.quote.force` + parameter is deprecated and will be removed in a future + release. This flag has no effect on the behavior of the + :meth:`.IdentifierPreparer.quote` method; please refer to + :class:`.quoted_name`. + + """ + if force is not None: + # not using the util.deprecated_params() decorator in this + # case because of the additional function call overhead on this + # very performance-critical spot. + util.warn_deprecated( + "The IdentifierPreparer.quote.force parameter is " + "deprecated and will be removed in a future release. This " + "flag has no effect on the behavior of the " + "IdentifierPreparer.quote method; please refer to " + "quoted_name().", + # deprecated 0.9. warning from 1.3 + version="0.9", + ) + + force = getattr(ident, "quote", None) + + if force is None: + if ident in self._strings: + return self._strings[ident] + else: + if self._requires_quotes(ident): + self._strings[ident] = self.quote_identifier(ident) + else: + self._strings[ident] = ident + return self._strings[ident] + elif force: + return self.quote_identifier(ident) + else: + return ident + + def format_collation(self, collation_name): + if self.quote_case_sensitive_collations: + return self.quote(collation_name) + else: + return collation_name + + def format_sequence(self, sequence, use_schema=True): + name = self.quote(sequence.name) + + effective_schema = self.schema_for_object(sequence) + + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): + name = self.quote_schema(effective_schema) + "." + name + return name + + def format_label(self, label, name=None): + return self.quote(name or label.name) + + def format_alias(self, alias, name=None): + return self.quote(name or alias.name) + + def format_savepoint(self, savepoint, name=None): + # Running the savepoint name through quoting is unnecessary + # for all known dialects. This is here to support potential + # third party use cases + ident = name or savepoint.ident + if self._requires_quotes(ident): + ident = self.quote_identifier(ident) + return ident + + @util.preload_module("sqlalchemy.sql.naming") + def format_constraint(self, constraint, _alembic_quote=True): + naming = util.preloaded.sql_naming + + if constraint.name is elements._NONE_NAME: + name = naming._constraint_name_for_table( + constraint, constraint.table + ) + + if name is None: + return None + else: + name = constraint.name + + if constraint.__visit_name__ == "index": + return self.truncate_and_render_index_name( + name, _alembic_quote=_alembic_quote + ) + else: + return self.truncate_and_render_constraint_name( + name, _alembic_quote=_alembic_quote + ) + + def truncate_and_render_index_name(self, name, _alembic_quote=True): + # calculate these at format time so that ad-hoc changes + # to dialect.max_identifier_length etc. can be reflected + # as IdentifierPreparer is long lived + max_ = ( + self.dialect.max_index_name_length + or self.dialect.max_identifier_length + ) + return self._truncate_and_render_maxlen_name( + name, max_, _alembic_quote + ) + + def truncate_and_render_constraint_name(self, name, _alembic_quote=True): + # calculate these at format time so that ad-hoc changes + # to dialect.max_identifier_length etc. can be reflected + # as IdentifierPreparer is long lived + max_ = ( + self.dialect.max_constraint_name_length + or self.dialect.max_identifier_length + ) + return self._truncate_and_render_maxlen_name( + name, max_, _alembic_quote + ) + + def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote): + if isinstance(name, elements._truncated_label): + if len(name) > max_: + name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:] + else: + self.dialect.validate_identifier(name) + + if not _alembic_quote: + return name + else: + return self.quote(name) + + def format_index(self, index): + return self.format_constraint(index) + + def format_table(self, table, use_schema=True, name=None): + """Prepare a quoted table and schema name.""" + + if name is None: + name = table.name + + result = self.quote(name) + + effective_schema = self.schema_for_object(table) + + if not self.omit_schema and use_schema and effective_schema: + result = self.quote_schema(effective_schema) + "." + result + return result + + def format_schema(self, name): + """Prepare a quoted schema name.""" + + return self.quote(name) + + def format_label_name( + self, + name, + anon_map=None, + ): + """Prepare a quoted column name.""" + + if anon_map is not None and isinstance( + name, elements._truncated_label + ): + name = name.apply_map(anon_map) + + return self.quote(name) + + def format_column( + self, + column, + use_table=False, + name=None, + table_name=None, + use_schema=False, + anon_map=None, + ): + """Prepare a quoted column name.""" + + if name is None: + name = column.name + + if anon_map is not None and isinstance( + name, elements._truncated_label + ): + name = name.apply_map(anon_map) + + if not getattr(column, "is_literal", False): + if use_table: + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + self.quote(name) + ) + else: + return self.quote(name) + else: + # literal textual elements get stuck into ColumnClause a lot, + # which shouldn't get quoted + + if use_table: + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + name + ) + else: + return name + + def format_table_seq(self, table, use_schema=True): + """Format table name and schema as a tuple.""" + + # Dialects with more levels in their fully qualified references + # ('database', 'owner', etc.) could override this and return + # a longer sequence. + + effective_schema = self.schema_for_object(table) + + if not self.omit_schema and use_schema and effective_schema: + return ( + self.quote_schema(effective_schema), + self.format_table(table, use_schema=False), + ) + else: + return (self.format_table(table, use_schema=False),) + + @util.memoized_property + def _r_identifiers(self): + initial, final, escaped_final = [ + re.escape(s) + for s in ( + self.initial_quote, + self.final_quote, + self._escape_identifier(self.final_quote), + ) + ] + r = re.compile( + r"(?:" + r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s" + r"|([^\.]+))(?=\.|$))+" + % {"initial": initial, "final": final, "escaped": escaped_final} + ) + return r + + def unformat_identifiers(self, identifiers): + """Unpack 'schema.table.column'-like strings into components.""" + + r = self._r_identifiers + return [ + self._unescape_identifier(i) + for i in [a or b for a, b in r.findall(identifiers)] + ] diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py new file mode 100644 index 0000000..920c8b3 --- /dev/null +++ b/lib/sqlalchemy/sql/crud.py @@ -0,0 +1,1091 @@ +# sql/crud.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Functions used by compiler.py to determine the parameters rendered +within INSERT and UPDATE statements. + +""" +import functools +import operator + +from . import coercions +from . import dml +from . import elements +from . import roles +from .selectable import Select +from .. import exc +from .. import util + +REQUIRED = util.symbol( + "REQUIRED", + """ +Placeholder for the value within a :class:`.BindParameter` +which is required to be present when the statement is passed +to :meth:`_engine.Connection.execute`. + +This symbol is typically used when a :func:`_expression.insert` +or :func:`_expression.update` statement is compiled without parameter +values present. + +""", +) + + +def _get_crud_params(compiler, stmt, compile_state, **kw): + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + Also generates the Compiled object's postfetch, prefetch, and + returning column collections, used for default handling and ultimately + populating the CursorResult's prefetch_cols() and postfetch_cols() + collections. + + """ + + compiler.postfetch = [] + compiler.insert_prefetch = [] + compiler.update_prefetch = [] + compiler.returning = [] + + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + ( + _column_as_key, + _getattr_col_key, + _col_bind_name, + ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state) + + compiler._key_getters_for_crud_column = getters + + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns + if compiler.column_keys is None and compile_state._no_parameters: + return [ + ( + c, + compiler.preparer.format_column(c), + _create_bind_param(compiler, c, None, required=True), + ) + for c in stmt.table.columns + ] + + if compile_state._has_multi_parameters: + spd = compile_state._multi_parameters[0] + stmt_parameter_tuples = list(spd.items()) + elif compile_state._ordered_values: + spd = compile_state._dict_parameters + stmt_parameter_tuples = compile_state._ordered_values + elif compile_state._dict_parameters: + spd = compile_state._dict_parameters + stmt_parameter_tuples = list(spd.items()) + else: + stmt_parameter_tuples = spd = None + + # if we have statement parameters - set defaults in the + # compiled params + if compiler.column_keys is None: + parameters = {} + elif stmt_parameter_tuples: + parameters = dict( + (_column_as_key(key), REQUIRED) + for key in compiler.column_keys + if key not in spd + ) + else: + parameters = dict( + (_column_as_key(key), REQUIRED) for key in compiler.column_keys + ) + + # create a list of column assignment clauses as tuples + values = [] + + if stmt_parameter_tuples is not None: + _get_stmt_parameter_tuples_params( + compiler, + compile_state, + parameters, + stmt_parameter_tuples, + _column_as_key, + values, + kw, + ) + + check_columns = {} + + # special logic that only occurs for multi-table UPDATE + # statements + if compile_state.isupdate and compile_state.is_multitable: + _get_update_multitable_params( + compiler, + stmt, + compile_state, + stmt_parameter_tuples, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, + ) + + if compile_state.isinsert and stmt._select_names: + _scan_insert_from_select_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, + ) + else: + _scan_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, + ) + + if parameters and stmt_parameter_tuples: + check = ( + set(parameters) + .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples) + .difference(check_columns) + ) + if check: + raise exc.CompileError( + "Unconsumed column names: %s" + % (", ".join("%s" % (c,) for c in check)) + ) + + if compile_state._has_multi_parameters: + values = _extend_values_for_multiparams( + compiler, + stmt, + compile_state, + values, + _column_as_key, + kw, + ) + elif ( + not values + and compiler.for_executemany + and compiler.dialect.supports_default_metavalue + ): + # convert an "INSERT DEFAULT VALUES" + # into INSERT (firstcol) VALUES (DEFAULT) which can be turned + # into an in-place multi values. This supports + # insert_executemany_returning mode :) + values = [ + ( + stmt.table.columns[0], + compiler.preparer.format_column(stmt.table.columns[0]), + "DEFAULT", + ) + ] + + return values + + +def _create_bind_param( + compiler, col, value, process=True, required=False, name=None, **kw +): + if name is None: + name = col.key + bindparam = elements.BindParameter( + name, value, type_=col.type, required=required + ) + bindparam._is_crud = True + if process: + bindparam = bindparam._compiler_dispatch(compiler, **kw) + return bindparam + + +def _handle_values_anonymous_param(compiler, col, value, name, **kw): + # the insert() and update() constructs as of 1.4 will now produce anonymous + # bindparam() objects in the values() collections up front when given plain + # literal values. This is so that cache key behaviors, which need to + # produce bound parameters in deterministic order without invoking any + # compilation here, can be applied to these constructs when they include + # values() (but not yet multi-values, which are not included in caching + # right now). + # + # in order to produce the desired "crud" style name for these parameters, + # which will also be targetable in engine/default.py through the usual + # conventions, apply our desired name to these unique parameters by + # populating the compiler truncated names cache with the desired name, + # rather than having + # compiler.visit_bindparam()->compiler._truncated_identifier make up a + # name. Saves on call counts also. + + # for INSERT/UPDATE that's a CTE, we don't need names to match to + # external parameters and these would also conflict in the case where + # multiple insert/update are combined together using CTEs + is_cte = "visiting_cte" in kw + + if ( + not is_cte + and value.unique + and isinstance(value.key, elements._truncated_label) + ): + compiler.truncated_names[("bindparam", value.key)] = name + + if value.type._isnull: + # either unique parameter, or other bound parameters that were + # passed in directly + # set type to that of the column unconditionally + value = value._with_binary_element_type(col.type) + + return value._compiler_dispatch(compiler, **kw) + + +def _key_getters_for_crud_column(compiler, stmt, compile_state): + if compile_state.isupdate and compile_state._extra_froms: + # when extra tables are present, refer to the columns + # in those extra tables as table-qualified, including in + # dictionaries and when rendering bind param names. + # the "main" table of the statement remains unqualified, + # allowing the most compatibility with a non-multi-table + # statement. + _et = set(compile_state._extra_froms) + + c_key_role = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) + + def _column_as_key(key): + str_key = c_key_role(key) + if hasattr(key, "table") and key.table in _et: + return (key.table.name, str_key) + else: + return str_key + + def _getattr_col_key(col): + if col.table in _et: + return (col.table.name, col.key) + else: + return col.key + + def _col_bind_name(col): + if col.table in _et: + return "%s_%s" % (col.table.name, col.key) + else: + return col.key + + else: + _column_as_key = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) + _getattr_col_key = _col_bind_name = operator.attrgetter("key") + + return _column_as_key, _getattr_col_key, _col_bind_name + + +def _scan_insert_from_select_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, +): + + ( + need_pks, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + ) = _get_returning_modifiers(compiler, stmt, compile_state) + + cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names] + + assert compiler.stack[-1]["selectable"] is stmt + + compiler.stack[-1]["insert_from_select"] = stmt.select + + add_select_cols = [] + if stmt.include_insert_from_select_defaults: + col_set = set(cols) + for col in stmt.table.columns: + if col not in col_set and col.default: + cols.append(col) + + for c in cols: + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + parameters.pop(col_key) + values.append((c, compiler.preparer.format_column(c), None)) + else: + _append_param_insert_select_hasdefault( + compiler, stmt, c, add_select_cols, kw + ) + + if add_select_cols: + values.extend(add_select_cols) + ins_from_select = compiler.stack[-1]["insert_from_select"] + if not isinstance(ins_from_select, Select): + raise exc.CompileError( + "Can't extend statement for INSERT..FROM SELECT to include " + "additional default-holding column(s) " + "%s. Convert the selectable to a subquery() first, or pass " + "include_defaults=False to Insert.from_select() to skip these " + "columns." + % (", ".join(repr(key) for _, key, _ in add_select_cols),) + ) + ins_from_select = ins_from_select._generate() + # copy raw_columns + ins_from_select._raw_columns = list(ins_from_select._raw_columns) + [ + expr for col, col_expr, expr in add_select_cols + ] + compiler.stack[-1]["insert_from_select"] = ins_from_select + + +def _scan_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + kw, +): + ( + need_pks, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + ) = _get_returning_modifiers(compiler, stmt, compile_state) + + if compile_state._parameter_ordering: + parameter_ordering = [ + _column_as_key(key) for key in compile_state._parameter_ordering + ] + ordered_keys = set(parameter_ordering) + cols = [ + stmt.table.c[key] + for key in parameter_ordering + if isinstance(key, util.string_types) and key in stmt.table.c + ] + [c for c in stmt.table.c if c.key not in ordered_keys] + + else: + cols = stmt.table.columns + + for c in cols: + # scan through every column in the target table + + col_key = _getattr_col_key(c) + + if col_key in parameters and col_key not in check_columns: + # parameter is present for the column. use that. + + _append_param_parameter( + compiler, + stmt, + compile_state, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + values, + kw, + ) + + elif compile_state.isinsert: + # no parameter is present and it's an insert. + + if c.primary_key and need_pks: + # it's a primary key column, it will need to be generated by a + # default generator of some kind, and the statement expects + # inserted_primary_key to be available. + + if implicit_returning: + # we can use RETURNING, find out how to invoke this + # column and get the value where RETURNING is an option. + # we can inline server-side functions in this case. + + _append_param_insert_pk_returning( + compiler, stmt, c, values, kw + ) + else: + # otherwise, find out how to invoke this column + # and get its value where RETURNING is not an option. + # if we have to invoke a server-side function, we need + # to pre-execute it. or if this is a straight + # autoincrement column and the dialect supports it + # we can use cursor.lastrowid. + + _append_param_insert_pk_no_returning( + compiler, stmt, c, values, kw + ) + + elif c.default is not None: + # column has a default, but it's not a pk column, or it is but + # we don't need to get the pk back. + _append_param_insert_hasdefault( + compiler, stmt, c, implicit_return_defaults, values, kw + ) + + elif c.server_default is not None: + # column has a DDL-level default, and is either not a pk + # column or we don't need the pk. + if implicit_return_defaults and c in implicit_return_defaults: + compiler.returning.append(c) + elif not c.primary_key: + compiler.postfetch.append(c) + elif implicit_return_defaults and c in implicit_return_defaults: + compiler.returning.append(c) + elif ( + c.primary_key + and c is not stmt.table._autoincrement_column + and not c.nullable + ): + _warn_pk_with_no_anticipated_value(c) + + elif compile_state.isupdate: + # no parameter is present and it's an insert. + + _append_param_update( + compiler, + compile_state, + stmt, + c, + implicit_return_defaults, + values, + kw, + ) + + +def _append_param_parameter( + compiler, + stmt, + compile_state, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + values, + kw, +): + value = parameters.pop(col_key) + + col_value = compiler.preparer.format_column( + c, use_table=compile_state.include_table_with_column_exprs + ) + + if coercions._is_literal(value): + value = _create_bind_param( + compiler, + c, + value, + required=value is REQUIRED, + name=_col_bind_name(c) + if not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c), + **kw + ) + elif value._is_bind_parameter: + value = _handle_values_anonymous_param( + compiler, + c, + value, + name=_col_bind_name(c) + if not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c), + **kw + ) + else: + # value is a SQL expression + value = compiler.process(value.self_group(), **kw) + + if compile_state.isupdate: + if implicit_return_defaults and c in implicit_return_defaults: + compiler.returning.append(c) + + else: + compiler.postfetch.append(c) + else: + if c.primary_key: + + if implicit_returning: + compiler.returning.append(c) + elif compiler.dialect.postfetch_lastrowid: + compiler.postfetch_lastrowid = True + + elif implicit_return_defaults and c in implicit_return_defaults: + compiler.returning.append(c) + + else: + # postfetch specifically means, "we can SELECT the row we just + # inserted by primary key to get back the server generated + # defaults". so by definition this can't be used to get the + # primary key value back, because we need to have it ahead of + # time. + + compiler.postfetch.append(c) + + values.append((c, col_value, value)) + + +def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): + """Create a primary key expression in the INSERT statement where + we want to populate result.inserted_primary_key and RETURNING + is available. + + """ + if c.default is not None: + if c.default.is_sequence: + if compiler.dialect.supports_sequences and ( + not c.default.optional + or not compiler.dialect.sequences_optional + ): + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default, **kw), + ) + ) + compiler.returning.append(c) + elif c.default.is_clause_element: + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default.arg.self_group(), **kw), + ) + ) + compiler.returning.append(c) + else: + # client side default. OK we can't use RETURNING, need to + # do a "prefetch", which in fact fetches the default value + # on the Python side + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) + elif c is stmt.table._autoincrement_column or c.server_default is not None: + compiler.returning.append(c) + elif not c.nullable: + # no .default, no .server_default, not autoincrement, we have + # no indication this primary key column will have any value + _warn_pk_with_no_anticipated_value(c) + + +def _append_param_insert_pk_no_returning(compiler, stmt, c, values, kw): + """Create a primary key expression in the INSERT statement where + we want to populate result.inserted_primary_key and we cannot use + RETURNING. + + Depending on the kind of default here we may create a bound parameter + in the INSERT statement and pre-execute a default generation function, + or we may use cursor.lastrowid if supported by the dialect. + + + """ + + if ( + # column has a Python-side default + c.default is not None + and ( + # and it either is not a sequence, or it is and we support + # sequences and want to invoke it + not c.default.is_sequence + or ( + compiler.dialect.supports_sequences + and ( + not c.default.optional + or not compiler.dialect.sequences_optional + ) + ) + ) + ) or ( + # column is the "autoincrement column" + c is stmt.table._autoincrement_column + and ( + # dialect can't use cursor.lastrowid + not compiler.dialect.postfetch_lastrowid + and ( + # column has a Sequence and we support those + ( + c.default is not None + and c.default.is_sequence + and compiler.dialect.supports_sequences + ) + or + # column has no default on it, but dialect can run the + # "autoincrement" mechanism explicitly, e.g. PostgreSQL + # SERIAL we know the sequence name + ( + c.default is None + and compiler.dialect.preexecute_autoincrement_sequences + ) + ) + ) + ): + # do a pre-execute of the default + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) + elif ( + c.default is None + and c.server_default is None + and not c.nullable + and c is not stmt.table._autoincrement_column + ): + # no .default, no .server_default, not autoincrement, we have + # no indication this primary key column will have any value + _warn_pk_with_no_anticipated_value(c) + elif compiler.dialect.postfetch_lastrowid: + # finally, where it seems like there will be a generated primary key + # value and we haven't set up any other way to fetch it, and the + # dialect supports cursor.lastrowid, switch on the lastrowid flag so + # that the DefaultExecutionContext calls upon cursor.lastrowid + compiler.postfetch_lastrowid = True + + +def _append_param_insert_hasdefault( + compiler, stmt, c, implicit_return_defaults, values, kw +): + if c.default.is_sequence: + if compiler.dialect.supports_sequences and ( + not c.default.optional or not compiler.dialect.sequences_optional + ): + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default, **kw), + ) + ) + if implicit_return_defaults and c in implicit_return_defaults: + compiler.returning.append(c) + elif not c.primary_key: + compiler.postfetch.append(c) + elif c.default.is_clause_element: + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process(c.default.arg.self_group(), **kw), + ) + ) + + if implicit_return_defaults and c in implicit_return_defaults: + compiler.returning.append(c) + elif not c.primary_key: + # don't add primary key column to postfetch + compiler.postfetch.append(c) + else: + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + ) + ) + + +def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw): + + if c.default.is_sequence: + if compiler.dialect.supports_sequences and ( + not c.default.optional or not compiler.dialect.sequences_optional + ): + values.append( + (c, compiler.preparer.format_column(c), c.default.next_value()) + ) + elif c.default.is_clause_element: + values.append( + (c, compiler.preparer.format_column(c), c.default.arg.self_group()) + ) + else: + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param( + compiler, c, process=False, **kw + ), + ) + ) + + +def _append_param_update( + compiler, compile_state, stmt, c, implicit_return_defaults, values, kw +): + + include_table = compile_state.include_table_with_column_exprs + if c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + ( + c, + compiler.preparer.format_column( + c, + use_table=include_table, + ), + compiler.process(c.onupdate.arg.self_group(), **kw), + ) + ) + if implicit_return_defaults and c in implicit_return_defaults: + compiler.returning.append(c) + else: + compiler.postfetch.append(c) + else: + values.append( + ( + c, + compiler.preparer.format_column( + c, + use_table=include_table, + ), + _create_update_prefetch_bind_param(compiler, c, **kw), + ) + ) + elif c.server_onupdate is not None: + if implicit_return_defaults and c in implicit_return_defaults: + compiler.returning.append(c) + else: + compiler.postfetch.append(c) + elif ( + implicit_return_defaults + and (stmt._return_defaults_columns or not stmt._return_defaults) + and c in implicit_return_defaults + ): + compiler.returning.append(c) + + +def _create_insert_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) + compiler.insert_prefetch.append(c) + return param + + +def _create_update_prefetch_bind_param( + compiler, c, process=True, name=None, **kw +): + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) + compiler.update_prefetch.append(c) + return param + + +class _multiparam_column(elements.ColumnElement): + _is_multiparam_column = True + + def __init__(self, original, index): + self.index = index + self.key = "%s_m%d" % (original.key, index + 1) + self.original = original + self.default = original.default + self.type = original.type + + def compare(self, other, **kw): + raise NotImplementedError() + + def _copy_internals(self, other, **kw): + raise NotImplementedError() + + def __eq__(self, other): + return ( + isinstance(other, _multiparam_column) + and other.key == self.key + and other.original == self.original + ) + + +def _process_multiparam_default_bind(compiler, stmt, c, index, kw): + if not c.default: + raise exc.CompileError( + "INSERT value for column %s is explicitly rendered as a bound" + "parameter in the VALUES clause; " + "a Python-side value or SQL expression is required" % c + ) + elif c.default.is_clause_element: + return compiler.process(c.default.arg.self_group(), **kw) + elif c.default.is_sequence: + # these conditions would have been established + # by append_param_insert_(?:hasdefault|pk_returning|pk_no_returning) + # in order for us to be here, so these don't need to be + # checked + # assert compiler.dialect.supports_sequences and ( + # not c.default.optional + # or not compiler.dialect.sequences_optional + # ) + return compiler.process(c.default, **kw) + else: + col = _multiparam_column(c, index) + if isinstance(stmt, dml.Insert): + return _create_insert_prefetch_bind_param(compiler, col, **kw) + else: + return _create_update_prefetch_bind_param(compiler, col, **kw) + + +def _get_update_multitable_params( + compiler, + stmt, + compile_state, + stmt_parameter_tuples, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, +): + normalized_params = dict( + (coercions.expect(roles.DMLColumnRole, c), param) + for c, param in stmt_parameter_tuples + ) + + include_table = compile_state.include_table_with_column_exprs + + affected_tables = set() + for t in compile_state._extra_froms: + for c in t.c: + if c in normalized_params: + affected_tables.add(t) + check_columns[_getattr_col_key(c)] = c + value = normalized_params[c] + + col_value = compiler.process(c, include_table=include_table) + if coercions._is_literal(value): + value = _create_bind_param( + compiler, + c, + value, + required=value is REQUIRED, + name=_col_bind_name(c), + **kw # TODO: no test coverage for literal binds here + ) + elif value._is_bind_parameter: + value = _handle_values_anonymous_param( + compiler, c, value, name=_col_bind_name(c), **kw + ) + else: + compiler.postfetch.append(c) + value = compiler.process(value.self_group(), **kw) + values.append((c, col_value, value)) + # determine tables which are actually to be updated - process onupdate + # and server_onupdate for these + for t in affected_tables: + for c in t.c: + if c in normalized_params: + continue + elif c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + ( + c, + compiler.process(c, include_table=include_table), + compiler.process( + c.onupdate.arg.self_group(), **kw + ), + ) + ) + compiler.postfetch.append(c) + else: + values.append( + ( + c, + compiler.process(c, include_table=include_table), + _create_update_prefetch_bind_param( + compiler, c, name=_col_bind_name(c), **kw + ), + ) + ) + elif c.server_onupdate is not None: + compiler.postfetch.append(c) + + +def _extend_values_for_multiparams( + compiler, + stmt, + compile_state, + values, + _column_as_key, + kw, +): + values_0 = values + values = [values] + + for i, row in enumerate(compile_state._multi_parameters[1:]): + extension = [] + + row = {_column_as_key(key): v for key, v in row.items()} + + for (col, col_expr, param) in values_0: + if col.key in row: + key = col.key + + if coercions._is_literal(row[key]): + new_param = _create_bind_param( + compiler, + col, + row[key], + name="%s_m%d" % (col.key, i + 1), + **kw + ) + else: + new_param = compiler.process(row[key].self_group(), **kw) + else: + new_param = _process_multiparam_default_bind( + compiler, stmt, col, i, kw + ) + + extension.append((col, col_expr, new_param)) + + values.append(extension) + + return values + + +def _get_stmt_parameter_tuples_params( + compiler, + compile_state, + parameters, + stmt_parameter_tuples, + _column_as_key, + values, + kw, +): + + for k, v in stmt_parameter_tuples: + colkey = _column_as_key(k) + if colkey is not None: + parameters.setdefault(colkey, v) + else: + # a non-Column expression on the left side; + # add it to values() in an "as-is" state, + # coercing right side to bound param + + # note one of the main use cases for this is array slice + # updates on PostgreSQL, as the left side is also an expression. + + col_expr = compiler.process( + k, include_table=compile_state.include_table_with_column_exprs + ) + + if coercions._is_literal(v): + v = compiler.process( + elements.BindParameter(None, v, type_=k.type), **kw + ) + else: + if v._is_bind_parameter and v.type._isnull: + # either unique parameter, or other bound parameters that + # were passed in directly + # set type to that of the column unconditionally + v = v._with_binary_element_type(k.type) + + v = compiler.process(v.self_group(), **kw) + + values.append((k, col_expr, v)) + + +def _get_returning_modifiers(compiler, stmt, compile_state): + + need_pks = ( + compile_state.isinsert + and not stmt._inline + and ( + not compiler.for_executemany + or ( + compiler.dialect.insert_executemany_returning + and stmt._return_defaults + ) + ) + and not stmt._returning + and not compile_state._has_multi_parameters + ) + + implicit_returning = ( + need_pks + and compiler.dialect.implicit_returning + and stmt.table.implicit_returning + ) + + if compile_state.isinsert: + implicit_return_defaults = implicit_returning and stmt._return_defaults + elif compile_state.isupdate: + implicit_return_defaults = ( + compiler.dialect.implicit_returning + and stmt.table.implicit_returning + and stmt._return_defaults + ) + else: + # this line is unused, currently we are always + # isinsert or isupdate + implicit_return_defaults = False # pragma: no cover + + if implicit_return_defaults: + if not stmt._return_defaults_columns: + implicit_return_defaults = set(stmt.table.c) + else: + implicit_return_defaults = set(stmt._return_defaults_columns) + + postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid + + return ( + need_pks, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + ) + + +def _warn_pk_with_no_anticipated_value(c): + msg = ( + "Column '%s.%s' is marked as a member of the " + "primary key for table '%s', " + "but has no Python-side or server-side default generator indicated, " + "nor does it indicate 'autoincrement=True' or 'nullable=True', " + "and no explicit value is passed. " + "Primary key columns typically may not store NULL." + % (c.table.fullname, c.name, c.table.fullname) + ) + if len(c.table.primary_key) > 1: + msg += ( + " Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be " + "indicated explicitly for composite (e.g. multicolumn) primary " + "keys if AUTO_INCREMENT/SERIAL/IDENTITY " + "behavior is expected for one of the columns in the primary key. " + "CREATE TABLE statements are impacted by this change as well on " + "most backends." + ) + util.warn(msg) diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py new file mode 100644 index 0000000..e608052 --- /dev/null +++ b/lib/sqlalchemy/sql/ddl.py @@ -0,0 +1,1341 @@ +# sql/ddl.py +# Copyright (C) 2009-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +""" +Provides the hierarchy of DDL-defining schema items as well as routines +to invoke them for a create/drop call. + +""" + +from . import roles +from .base import _bind_or_error +from .base import _generative +from .base import Executable +from .base import SchemaVisitor +from .elements import ClauseElement +from .. import exc +from .. import util +from ..util import topological + + +class _DDLCompiles(ClauseElement): + _hierarchy_supports_caching = False + """disable cache warnings for all _DDLCompiles subclasses. """ + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + + return dialect.ddl_compiler(dialect, self, **kw) + + def _compile_w_cache(self, *arg, **kw): + raise NotImplementedError() + + +class DDLElement(roles.DDLRole, Executable, _DDLCompiles): + """Base class for DDL expression constructs. + + This class is the base for the general purpose :class:`.DDL` class, + as well as the various create/drop clause constructs such as + :class:`.CreateTable`, :class:`.DropTable`, :class:`.AddConstraint`, + etc. + + :class:`.DDLElement` integrates closely with SQLAlchemy events, + introduced in :ref:`event_toplevel`. An instance of one is + itself an event receiving callable:: + + event.listen( + users, + 'after_create', + AddConstraint(constraint).execute_if(dialect='postgresql') + ) + + .. seealso:: + + :class:`.DDL` + + :class:`.DDLEvents` + + :ref:`event_toplevel` + + :ref:`schema_ddl_sequences` + + """ + + _execution_options = Executable._execution_options.union( + {"autocommit": True} + ) + + target = None + on = None + dialect = None + callable_ = None + + def _execute_on_connection( + self, connection, multiparams, params, execution_options + ): + return connection._execute_ddl( + self, multiparams, params, execution_options + ) + + @util.deprecated_20( + ":meth:`.DDLElement.execute`", + alternative="All statement execution in SQLAlchemy 2.0 is performed " + "by the :meth:`_engine.Connection.execute` method of " + ":class:`_engine.Connection`, " + "or in the ORM by the :meth:`.Session.execute` method of " + ":class:`.Session`.", + ) + def execute(self, bind=None, target=None): + """Execute this DDL immediately. + + Executes the DDL statement in isolation using the supplied + :class:`.Connectable` or + :class:`.Connectable` assigned to the ``.bind`` + property, if not supplied. If the DDL has a conditional ``on`` + criteria, it will be invoked with None as the event. + + :param bind: + Optional, an ``Engine`` or ``Connection``. If not supplied, a valid + :class:`.Connectable` must be present in the + ``.bind`` property. + + :param target: + Optional, defaults to None. The target :class:`_schema.SchemaItem` + for the execute call. This is equivalent to passing the + :class:`_schema.SchemaItem` to the :meth:`.DDLElement.against` + method and then invoking :meth:`_schema.DDLElement.execute` + upon the resulting :class:`_schema.DDLElement` object. See + :meth:`.DDLElement.against` for further detail. + + """ + + if bind is None: + bind = _bind_or_error(self) + + if self._should_execute(target, bind): + return bind.execute(self.against(target)) + else: + bind.engine.logger.info("DDL execution skipped, criteria not met.") + + @_generative + def against(self, target): + """Return a copy of this :class:`_schema.DDLElement` which will include + the given target. + + This essentially applies the given item to the ``.target`` attribute + of the returned :class:`_schema.DDLElement` object. This target + is then usable by event handlers and compilation routines in order to + provide services such as tokenization of a DDL string in terms of a + particular :class:`_schema.Table`. + + When a :class:`_schema.DDLElement` object is established as an event + handler for the :meth:`_events.DDLEvents.before_create` or + :meth:`_events.DDLEvents.after_create` events, and the event + then occurs for a given target such as a :class:`_schema.Constraint` + or :class:`_schema.Table`, that target is established with a copy + of the :class:`_schema.DDLElement` object using this method, which + then proceeds to the :meth:`_schema.DDLElement.execute` method + in order to invoke the actual DDL instruction. + + :param target: a :class:`_schema.SchemaItem` that will be the subject + of a DDL operation. + + :return: a copy of this :class:`_schema.DDLElement` with the + ``.target`` attribute assigned to the given + :class:`_schema.SchemaItem`. + + .. seealso:: + + :class:`_schema.DDL` - uses tokenization against the "target" when + processing the DDL string. + + """ + + self.target = target + + @_generative + def execute_if(self, dialect=None, callable_=None, state=None): + r"""Return a callable that will execute this + :class:`_ddl.DDLElement` conditionally within an event handler. + + Used to provide a wrapper for event listening:: + + event.listen( + metadata, + 'before_create', + DDL("my_ddl").execute_if(dialect='postgresql') + ) + + :param dialect: May be a string or tuple of strings. + If a string, it will be compared to the name of the + executing database dialect:: + + DDL('something').execute_if(dialect='postgresql') + + If a tuple, specifies multiple dialect names:: + + DDL('something').execute_if(dialect=('postgresql', 'mysql')) + + :param callable\_: A callable, which will be invoked with + four positional arguments as well as optional keyword + arguments: + + :ddl: + This DDL element. + + :target: + The :class:`_schema.Table` or :class:`_schema.MetaData` + object which is the + target of this event. May be None if the DDL is executed + explicitly. + + :bind: + The :class:`_engine.Connection` being used for DDL execution + + :tables: + Optional keyword argument - a list of Table objects which are to + be created/ dropped within a MetaData.create_all() or drop_all() + method call. + + :state: + Optional keyword argument - will be the ``state`` argument + passed to this function. + + :checkfirst: + Keyword argument, will be True if the 'checkfirst' flag was + set during the call to ``create()``, ``create_all()``, + ``drop()``, ``drop_all()``. + + If the callable returns a True value, the DDL statement will be + executed. + + :param state: any value which will be passed to the callable\_ + as the ``state`` keyword argument. + + .. seealso:: + + :class:`.DDLEvents` + + :ref:`event_toplevel` + + """ + self.dialect = dialect + self.callable_ = callable_ + self.state = state + + def _should_execute(self, target, bind, **kw): + if isinstance(self.dialect, util.string_types): + if self.dialect != bind.engine.name: + return False + elif isinstance(self.dialect, (tuple, list, set)): + if bind.engine.name not in self.dialect: + return False + if self.callable_ is not None and not self.callable_( + self, target, bind, state=self.state, **kw + ): + return False + + return True + + def __call__(self, target, bind, **kw): + """Execute the DDL as a ddl_listener.""" + + if self._should_execute(target, bind, **kw): + return bind.execute(self.against(target)) + + def bind(self): + if self._bind: + return self._bind + + def _set_bind(self, bind): + self._bind = bind + + bind = property(bind, _set_bind) + + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + return s + + +class DDL(DDLElement): + """A literal DDL statement. + + Specifies literal SQL DDL to be executed by the database. DDL objects + function as DDL event listeners, and can be subscribed to those events + listed in :class:`.DDLEvents`, using either :class:`_schema.Table` or + :class:`_schema.MetaData` objects as targets. + Basic templating support allows + a single DDL instance to handle repetitive tasks for multiple tables. + + Examples:: + + from sqlalchemy import event, DDL + + tbl = Table('users', metadata, Column('uid', Integer)) + event.listen(tbl, 'before_create', DDL('DROP TRIGGER users_trigger')) + + spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE') + event.listen(tbl, 'after_create', spow.execute_if(dialect='somedb')) + + drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE') + connection.execute(drop_spow) + + When operating on Table events, the following ``statement`` + string substitutions are available:: + + %(table)s - the Table name, with any required quoting applied + %(schema)s - the schema name, with any required quoting applied + %(fullname)s - the Table name including schema, quoted if needed + + The DDL's "context", if any, will be combined with the standard + substitutions noted above. Keys present in the context will override + the standard substitutions. + + """ + + __visit_name__ = "ddl" + + @util.deprecated_params( + bind=( + "2.0", + "The :paramref:`_ddl.DDL.bind` argument is deprecated and " + "will be removed in SQLAlchemy 2.0.", + ), + ) + def __init__(self, statement, context=None, bind=None): + """Create a DDL statement. + + :param statement: + A string or unicode string to be executed. Statements will be + processed with Python's string formatting operator using + a fixed set of string substitutions, as well as additional + substitutions provided by the optional :paramref:`.DDL.context` + parameter. + + A literal '%' in a statement must be escaped as '%%'. + + SQL bind parameters are not available in DDL statements. + + :param context: + Optional dictionary, defaults to None. These values will be + available for use in string substitutions on the DDL statement. + + :param bind: + Optional. A :class:`.Connectable`, used by + default when ``execute()`` is invoked without a bind argument. + + + .. seealso:: + + :class:`.DDLEvents` + + :ref:`event_toplevel` + + """ + + if not isinstance(statement, util.string_types): + raise exc.ArgumentError( + "Expected a string or unicode SQL statement, got '%r'" + % statement + ) + + self.statement = statement + self.context = context or {} + + self._bind = bind + + def __repr__(self): + return "<%s@%s; %s>" % ( + type(self).__name__, + id(self), + ", ".join( + [repr(self.statement)] + + [ + "%s=%r" % (key, getattr(self, key)) + for key in ("on", "context") + if getattr(self, key) + ] + ), + ) + + +class _CreateDropBase(DDLElement): + """Base class for DDL constructs that represent CREATE and DROP or + equivalents. + + The common theme of _CreateDropBase is a single + ``element`` attribute which refers to the element + to be created or dropped. + + """ + + @util.deprecated_params( + bind=( + "2.0", + "The :paramref:`_ddl.DDLElement.bind` argument is " + "deprecated and " + "will be removed in SQLAlchemy 2.0.", + ), + ) + def __init__( + self, + element, + bind=None, + if_exists=False, + if_not_exists=False, + _legacy_bind=None, + ): + self.element = element + if bind: + self.bind = bind + elif _legacy_bind: + self.bind = _legacy_bind + self.if_exists = if_exists + self.if_not_exists = if_not_exists + + @property + def stringify_dialect(self): + return self.element.create_drop_stringify_dialect + + def _create_rule_disable(self, compiler): + """Allow disable of _create_rule using a callable. + + Pass to _create_rule using + util.portable_instancemethod(self._create_rule_disable) + to retain serializability. + + """ + return False + + +class CreateSchema(_CreateDropBase): + """Represent a CREATE SCHEMA statement. + + The argument here is the string name of the schema. + + """ + + __visit_name__ = "create_schema" + + def __init__(self, name, quote=None, **kw): + """Create a new :class:`.CreateSchema` construct.""" + + self.quote = quote + super(CreateSchema, self).__init__(name, **kw) + + +class DropSchema(_CreateDropBase): + """Represent a DROP SCHEMA statement. + + The argument here is the string name of the schema. + + """ + + __visit_name__ = "drop_schema" + + def __init__(self, name, quote=None, cascade=False, **kw): + """Create a new :class:`.DropSchema` construct.""" + + self.quote = quote + self.cascade = cascade + super(DropSchema, self).__init__(name, **kw) + + +class CreateTable(_CreateDropBase): + """Represent a CREATE TABLE statement.""" + + __visit_name__ = "create_table" + + @util.deprecated_params( + bind=( + "2.0", + "The :paramref:`_ddl.CreateTable.bind` argument is deprecated and " + "will be removed in SQLAlchemy 2.0.", + ), + ) + def __init__( + self, + element, + bind=None, + include_foreign_key_constraints=None, + if_not_exists=False, + ): + """Create a :class:`.CreateTable` construct. + + :param element: a :class:`_schema.Table` that's the subject + of the CREATE + :param on: See the description for 'on' in :class:`.DDL`. + :param bind: See the description for 'bind' in :class:`.DDL`. + :param include_foreign_key_constraints: optional sequence of + :class:`_schema.ForeignKeyConstraint` objects that will be included + inline within the CREATE construct; if omitted, all foreign key + constraints that do not specify use_alter=True are included. + + .. versionadded:: 1.0.0 + + :param if_not_exists: if True, an IF NOT EXISTS operator will be + applied to the construct. + + .. versionadded:: 1.4.0b2 + + """ + super(CreateTable, self).__init__( + element, _legacy_bind=bind, if_not_exists=if_not_exists + ) + self.columns = [CreateColumn(column) for column in element.columns] + self.include_foreign_key_constraints = include_foreign_key_constraints + + +class _DropView(_CreateDropBase): + """Semi-public 'DROP VIEW' construct. + + Used by the test suite for dialect-agnostic drops of views. + This object will eventually be part of a public "view" API. + + """ + + __visit_name__ = "drop_view" + + +class CreateColumn(_DDLCompiles): + """Represent a :class:`_schema.Column` + as rendered in a CREATE TABLE statement, + via the :class:`.CreateTable` construct. + + This is provided to support custom column DDL within the generation + of CREATE TABLE statements, by using the + compiler extension documented in :ref:`sqlalchemy.ext.compiler_toplevel` + to extend :class:`.CreateColumn`. + + Typical integration is to examine the incoming :class:`_schema.Column` + object, and to redirect compilation if a particular flag or condition + is found:: + + from sqlalchemy import schema + from sqlalchemy.ext.compiler import compiles + + @compiles(schema.CreateColumn) + def compile(element, compiler, **kw): + column = element.element + + if "special" not in column.info: + return compiler.visit_create_column(element, **kw) + + text = "%s SPECIAL DIRECTIVE %s" % ( + column.name, + compiler.type_compiler.process(column.type) + ) + default = compiler.get_column_default_string(column) + if default is not None: + text += " DEFAULT " + default + + if not column.nullable: + text += " NOT NULL" + + if column.constraints: + text += " ".join( + compiler.process(const) + for const in column.constraints) + return text + + The above construct can be applied to a :class:`_schema.Table` + as follows:: + + from sqlalchemy import Table, Metadata, Column, Integer, String + from sqlalchemy import schema + + metadata = MetaData() + + table = Table('mytable', MetaData(), + Column('x', Integer, info={"special":True}, primary_key=True), + Column('y', String(50)), + Column('z', String(20), info={"special":True}) + ) + + metadata.create_all(conn) + + Above, the directives we've added to the :attr:`_schema.Column.info` + collection + will be detected by our custom compilation scheme:: + + CREATE TABLE mytable ( + x SPECIAL DIRECTIVE INTEGER NOT NULL, + y VARCHAR(50), + z SPECIAL DIRECTIVE VARCHAR(20), + PRIMARY KEY (x) + ) + + The :class:`.CreateColumn` construct can also be used to skip certain + columns when producing a ``CREATE TABLE``. This is accomplished by + creating a compilation rule that conditionally returns ``None``. + This is essentially how to produce the same effect as using the + ``system=True`` argument on :class:`_schema.Column`, which marks a column + as an implicitly-present "system" column. + + For example, suppose we wish to produce a :class:`_schema.Table` + which skips + rendering of the PostgreSQL ``xmin`` column against the PostgreSQL + backend, but on other backends does render it, in anticipation of a + triggered rule. A conditional compilation rule could skip this name only + on PostgreSQL:: + + from sqlalchemy.schema import CreateColumn + + @compiles(CreateColumn, "postgresql") + def skip_xmin(element, compiler, **kw): + if element.element.name == 'xmin': + return None + else: + return compiler.visit_create_column(element, **kw) + + + my_table = Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('xmin', Integer) + ) + + Above, a :class:`.CreateTable` construct will generate a ``CREATE TABLE`` + which only includes the ``id`` column in the string; the ``xmin`` column + will be omitted, but only against the PostgreSQL backend. + + """ + + __visit_name__ = "create_column" + + def __init__(self, element): + self.element = element + + +class DropTable(_CreateDropBase): + """Represent a DROP TABLE statement.""" + + __visit_name__ = "drop_table" + + @util.deprecated_params( + bind=( + "2.0", + "The :paramref:`_ddl.DropTable.bind` argument is " + "deprecated and " + "will be removed in SQLAlchemy 2.0.", + ), + ) + def __init__(self, element, bind=None, if_exists=False): + """Create a :class:`.DropTable` construct. + + :param element: a :class:`_schema.Table` that's the subject + of the DROP. + :param on: See the description for 'on' in :class:`.DDL`. + :param bind: See the description for 'bind' in :class:`.DDL`. + :param if_exists: if True, an IF EXISTS operator will be applied to the + construct. + + .. versionadded:: 1.4.0b2 + + """ + super(DropTable, self).__init__( + element, _legacy_bind=bind, if_exists=if_exists + ) + + +class CreateSequence(_CreateDropBase): + """Represent a CREATE SEQUENCE statement.""" + + __visit_name__ = "create_sequence" + + +class DropSequence(_CreateDropBase): + """Represent a DROP SEQUENCE statement.""" + + __visit_name__ = "drop_sequence" + + +class CreateIndex(_CreateDropBase): + """Represent a CREATE INDEX statement.""" + + __visit_name__ = "create_index" + + @util.deprecated_params( + bind=( + "2.0", + "The :paramref:`_ddl.CreateIndex.bind` argument is " + "deprecated and " + "will be removed in SQLAlchemy 2.0.", + ), + ) + def __init__(self, element, bind=None, if_not_exists=False): + """Create a :class:`.Createindex` construct. + + :param element: a :class:`_schema.Index` that's the subject + of the CREATE. + :param on: See the description for 'on' in :class:`.DDL`. + :param bind: See the description for 'bind' in :class:`.DDL`. + :param if_not_exists: if True, an IF NOT EXISTS operator will be + applied to the construct. + + .. versionadded:: 1.4.0b2 + + """ + super(CreateIndex, self).__init__( + element, _legacy_bind=bind, if_not_exists=if_not_exists + ) + + +class DropIndex(_CreateDropBase): + """Represent a DROP INDEX statement.""" + + __visit_name__ = "drop_index" + + @util.deprecated_params( + bind=( + "2.0", + "The :paramref:`_ddl.DropIndex.bind` argument is " + "deprecated and " + "will be removed in SQLAlchemy 2.0.", + ), + ) + def __init__(self, element, bind=None, if_exists=False): + """Create a :class:`.DropIndex` construct. + + :param element: a :class:`_schema.Index` that's the subject + of the DROP. + :param on: See the description for 'on' in :class:`.DDL`. + :param bind: See the description for 'bind' in :class:`.DDL`. + :param if_exists: if True, an IF EXISTS operator will be applied to the + construct. + + .. versionadded:: 1.4.0b2 + + """ + super(DropIndex, self).__init__( + element, _legacy_bind=bind, if_exists=if_exists + ) + + +class AddConstraint(_CreateDropBase): + """Represent an ALTER TABLE ADD CONSTRAINT statement.""" + + __visit_name__ = "add_constraint" + + def __init__(self, element, *args, **kw): + super(AddConstraint, self).__init__(element, *args, **kw) + element._create_rule = util.portable_instancemethod( + self._create_rule_disable + ) + + +class DropConstraint(_CreateDropBase): + """Represent an ALTER TABLE DROP CONSTRAINT statement.""" + + __visit_name__ = "drop_constraint" + + def __init__(self, element, cascade=False, **kw): + self.cascade = cascade + super(DropConstraint, self).__init__(element, **kw) + element._create_rule = util.portable_instancemethod( + self._create_rule_disable + ) + + +class SetTableComment(_CreateDropBase): + """Represent a COMMENT ON TABLE IS statement.""" + + __visit_name__ = "set_table_comment" + + +class DropTableComment(_CreateDropBase): + """Represent a COMMENT ON TABLE '' statement. + + Note this varies a lot across database backends. + + """ + + __visit_name__ = "drop_table_comment" + + +class SetColumnComment(_CreateDropBase): + """Represent a COMMENT ON COLUMN IS statement.""" + + __visit_name__ = "set_column_comment" + + +class DropColumnComment(_CreateDropBase): + """Represent a COMMENT ON COLUMN IS NULL statement.""" + + __visit_name__ = "drop_column_comment" + + +class DDLBase(SchemaVisitor): + def __init__(self, connection): + self.connection = connection + + +class SchemaGenerator(DDLBase): + def __init__( + self, dialect, connection, checkfirst=False, tables=None, **kwargs + ): + super(SchemaGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + self.memo = {} + + def _can_create_table(self, table): + self.dialect.validate_identifier(table.name) + effective_schema = self.connection.schema_for_object(table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or not self.dialect.has_table( + self.connection, table.name, schema=effective_schema + ) + + def _can_create_index(self, index): + effective_schema = self.connection.schema_for_object(index.table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or not self.dialect.has_index( + self.connection, + index.table.name, + index.name, + schema=effective_schema, + ) + + def _can_create_sequence(self, sequence): + effective_schema = self.connection.schema_for_object(sequence) + + return self.dialect.supports_sequences and ( + (not self.dialect.sequences_optional or not sequence.optional) + and ( + not self.checkfirst + or not self.dialect.has_sequence( + self.connection, sequence.name, schema=effective_schema + ) + ) + ) + + def visit_metadata(self, metadata): + if self.tables is not None: + tables = self.tables + else: + tables = list(metadata.tables.values()) + + collection = sort_tables_and_constraints( + [t for t in tables if self._can_create_table(t)] + ) + + seq_coll = [ + s + for s in metadata._sequences.values() + if s.column is None and self._can_create_sequence(s) + ] + + event_collection = [t for (t, fks) in collection if t is not None] + metadata.dispatch.before_create( + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) + + for seq in seq_coll: + self.traverse_single(seq, create_ok=True) + + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, + create_ok=True, + include_foreign_key_constraints=fkcs, + _is_metadata_operation=True, + ) + else: + for fkc in fkcs: + self.traverse_single(fkc) + + metadata.dispatch.after_create( + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) + + def visit_table( + self, + table, + create_ok=False, + include_foreign_key_constraints=None, + _is_metadata_operation=False, + ): + if not create_ok and not self._can_create_table(table): + return + + table.dispatch.before_create( + table, + self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self, + _is_metadata_operation=_is_metadata_operation, + ) + + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + if not self.dialect.supports_alter: + # e.g., don't omit any foreign key constraints + include_foreign_key_constraints = None + + self.connection.execute( + # fmt: off + CreateTable( + table, + include_foreign_key_constraints= # noqa + include_foreign_key_constraints, # noqa + ) + # fmt: on + ) + + if hasattr(table, "indexes"): + for index in table.indexes: + self.traverse_single(index, create_ok=True) + + if self.dialect.supports_comments and not self.dialect.inline_comments: + if table.comment is not None: + self.connection.execute(SetTableComment(table)) + + for column in table.columns: + if column.comment is not None: + self.connection.execute(SetColumnComment(column)) + + table.dispatch.after_create( + table, + self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self, + _is_metadata_operation=_is_metadata_operation, + ) + + def visit_foreign_key_constraint(self, constraint): + if not self.dialect.supports_alter: + return + self.connection.execute(AddConstraint(constraint)) + + def visit_sequence(self, sequence, create_ok=False): + if not create_ok and not self._can_create_sequence(sequence): + return + self.connection.execute(CreateSequence(sequence)) + + def visit_index(self, index, create_ok=False): + if not create_ok and not self._can_create_index(index): + return + self.connection.execute(CreateIndex(index)) + + +class SchemaDropper(DDLBase): + def __init__( + self, dialect, connection, checkfirst=False, tables=None, **kwargs + ): + super(SchemaDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + self.memo = {} + + def visit_metadata(self, metadata): + if self.tables is not None: + tables = self.tables + else: + tables = list(metadata.tables.values()) + + try: + unsorted_tables = [t for t in tables if self._can_drop_table(t)] + collection = list( + reversed( + sort_tables_and_constraints( + unsorted_tables, + filter_fn=lambda constraint: False + if not self.dialect.supports_alter + or constraint.name is None + else None, + ) + ) + ) + except exc.CircularDependencyError as err2: + if not self.dialect.supports_alter: + util.warn( + "Can't sort tables for DROP; an " + "unresolvable foreign key " + "dependency exists between tables: %s; and backend does " + "not support ALTER. To restore at least a partial sort, " + "apply use_alter=True to ForeignKey and " + "ForeignKeyConstraint " + "objects involved in the cycle to mark these as known " + "cycles that will be ignored." + % (", ".join(sorted([t.fullname for t in err2.cycles]))) + ) + collection = [(t, ()) for t in unsorted_tables] + else: + util.raise_( + exc.CircularDependencyError( + err2.args[0], + err2.cycles, + err2.edges, + msg="Can't sort tables for DROP; an " + "unresolvable foreign key " + "dependency exists between tables: %s. Please ensure " + "that the ForeignKey and ForeignKeyConstraint objects " + "involved in the cycle have " + "names so that they can be dropped using " + "DROP CONSTRAINT." + % ( + ", ".join( + sorted([t.fullname for t in err2.cycles]) + ) + ), + ), + from_=err2, + ) + + seq_coll = [ + s + for s in metadata._sequences.values() + if self._can_drop_sequence(s) + ] + + event_collection = [t for (t, fks) in collection if t is not None] + + metadata.dispatch.before_drop( + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) + + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, + drop_ok=True, + _is_metadata_operation=True, + _ignore_sequences=seq_coll, + ) + else: + for fkc in fkcs: + self.traverse_single(fkc) + + for seq in seq_coll: + self.traverse_single(seq, drop_ok=seq.column is None) + + metadata.dispatch.after_drop( + metadata, + self.connection, + tables=event_collection, + checkfirst=self.checkfirst, + _ddl_runner=self, + ) + + def _can_drop_table(self, table): + self.dialect.validate_identifier(table.name) + effective_schema = self.connection.schema_for_object(table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or self.dialect.has_table( + self.connection, table.name, schema=effective_schema + ) + + def _can_drop_index(self, index): + effective_schema = self.connection.schema_for_object(index.table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or self.dialect.has_index( + self.connection, + index.table.name, + index.name, + schema=effective_schema, + ) + + def _can_drop_sequence(self, sequence): + effective_schema = self.connection.schema_for_object(sequence) + return self.dialect.supports_sequences and ( + (not self.dialect.sequences_optional or not sequence.optional) + and ( + not self.checkfirst + or self.dialect.has_sequence( + self.connection, sequence.name, schema=effective_schema + ) + ) + ) + + def visit_index(self, index, drop_ok=False): + if not drop_ok and not self._can_drop_index(index): + return + + self.connection.execute(DropIndex(index)) + + def visit_table( + self, + table, + drop_ok=False, + _is_metadata_operation=False, + _ignore_sequences=(), + ): + if not drop_ok and not self._can_drop_table(table): + return + + table.dispatch.before_drop( + table, + self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self, + _is_metadata_operation=_is_metadata_operation, + ) + + self.connection.execute(DropTable(table)) + + # traverse client side defaults which may refer to server-side + # sequences. noting that some of these client side defaults may also be + # set up as server side defaults (see https://docs.sqlalchemy.org/en/ + # latest/core/defaults.html#associating-a-sequence-as-the-server-side- + # default), so have to be dropped after the table is dropped. + for column in table.columns: + if ( + column.default is not None + and column.default not in _ignore_sequences + ): + self.traverse_single(column.default) + + table.dispatch.after_drop( + table, + self.connection, + checkfirst=self.checkfirst, + _ddl_runner=self, + _is_metadata_operation=_is_metadata_operation, + ) + + def visit_foreign_key_constraint(self, constraint): + if not self.dialect.supports_alter: + return + self.connection.execute(DropConstraint(constraint)) + + def visit_sequence(self, sequence, drop_ok=False): + + if not drop_ok and not self._can_drop_sequence(sequence): + return + self.connection.execute(DropSequence(sequence)) + + +def sort_tables( + tables, + skip_fn=None, + extra_dependencies=None, +): + """Sort a collection of :class:`_schema.Table` objects based on + dependency. + + This is a dependency-ordered sort which will emit :class:`_schema.Table` + objects such that they will follow their dependent :class:`_schema.Table` + objects. + Tables are dependent on another based on the presence of + :class:`_schema.ForeignKeyConstraint` + objects as well as explicit dependencies + added by :meth:`_schema.Table.add_is_dependent_on`. + + .. warning:: + + The :func:`._schema.sort_tables` function cannot by itself + accommodate automatic resolution of dependency cycles between + tables, which are usually caused by mutually dependent foreign key + constraints. When these cycles are detected, the foreign keys + of these tables are omitted from consideration in the sort. + A warning is emitted when this condition occurs, which will be an + exception raise in a future release. Tables which are not part + of the cycle will still be returned in dependency order. + + To resolve these cycles, the + :paramref:`_schema.ForeignKeyConstraint.use_alter` parameter may be + applied to those constraints which create a cycle. Alternatively, + the :func:`_schema.sort_tables_and_constraints` function will + automatically return foreign key constraints in a separate + collection when cycles are detected so that they may be applied + to a schema separately. + + .. versionchanged:: 1.3.17 - a warning is emitted when + :func:`_schema.sort_tables` cannot perform a proper sort due to + cyclical dependencies. This will be an exception in a future + release. Additionally, the sort will continue to return + other tables not involved in the cycle in dependency order + which was not the case previously. + + :param tables: a sequence of :class:`_schema.Table` objects. + + :param skip_fn: optional callable which will be passed a + :class:`_schema.ForeignKey` object; if it returns True, this + constraint will not be considered as a dependency. Note this is + **different** from the same parameter in + :func:`.sort_tables_and_constraints`, which is + instead passed the owning :class:`_schema.ForeignKeyConstraint` object. + + :param extra_dependencies: a sequence of 2-tuples of tables which will + also be considered as dependent on each other. + + .. seealso:: + + :func:`.sort_tables_and_constraints` + + :attr:`_schema.MetaData.sorted_tables` - uses this function to sort + + + """ + + if skip_fn is not None: + + def _skip_fn(fkc): + for fk in fkc.elements: + if skip_fn(fk): + return True + else: + return None + + else: + _skip_fn = None + + return [ + t + for (t, fkcs) in sort_tables_and_constraints( + tables, + filter_fn=_skip_fn, + extra_dependencies=extra_dependencies, + _warn_for_cycles=True, + ) + if t is not None + ] + + +def sort_tables_and_constraints( + tables, filter_fn=None, extra_dependencies=None, _warn_for_cycles=False +): + """Sort a collection of :class:`_schema.Table` / + :class:`_schema.ForeignKeyConstraint` + objects. + + This is a dependency-ordered sort which will emit tuples of + ``(Table, [ForeignKeyConstraint, ...])`` such that each + :class:`_schema.Table` follows its dependent :class:`_schema.Table` + objects. + Remaining :class:`_schema.ForeignKeyConstraint` + objects that are separate due to + dependency rules not satisfied by the sort are emitted afterwards + as ``(None, [ForeignKeyConstraint ...])``. + + Tables are dependent on another based on the presence of + :class:`_schema.ForeignKeyConstraint` objects, explicit dependencies + added by :meth:`_schema.Table.add_is_dependent_on`, + as well as dependencies + stated here using the :paramref:`~.sort_tables_and_constraints.skip_fn` + and/or :paramref:`~.sort_tables_and_constraints.extra_dependencies` + parameters. + + :param tables: a sequence of :class:`_schema.Table` objects. + + :param filter_fn: optional callable which will be passed a + :class:`_schema.ForeignKeyConstraint` object, + and returns a value based on + whether this constraint should definitely be included or excluded as + an inline constraint, or neither. If it returns False, the constraint + will definitely be included as a dependency that cannot be subject + to ALTER; if True, it will **only** be included as an ALTER result at + the end. Returning None means the constraint is included in the + table-based result unless it is detected as part of a dependency cycle. + + :param extra_dependencies: a sequence of 2-tuples of tables which will + also be considered as dependent on each other. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :func:`.sort_tables` + + + """ + + fixed_dependencies = set() + mutable_dependencies = set() + + if extra_dependencies is not None: + fixed_dependencies.update(extra_dependencies) + + remaining_fkcs = set() + for table in tables: + for fkc in table.foreign_key_constraints: + if fkc.use_alter is True: + remaining_fkcs.add(fkc) + continue + + if filter_fn: + filtered = filter_fn(fkc) + + if filtered is True: + remaining_fkcs.add(fkc) + continue + + dependent_on = fkc.referred_table + if dependent_on is not table: + mutable_dependencies.add((dependent_on, table)) + + fixed_dependencies.update( + (parent, table) for parent in table._extra_dependencies + ) + + try: + candidate_sort = list( + topological.sort( + fixed_dependencies.union(mutable_dependencies), + tables, + ) + ) + except exc.CircularDependencyError as err: + if _warn_for_cycles: + util.warn( + "Cannot correctly sort tables; there are unresolvable cycles " + 'between tables "%s", which is usually caused by mutually ' + "dependent foreign key constraints. Foreign key constraints " + "involving these tables will not be considered; this warning " + "may raise an error in a future release." + % (", ".join(sorted(t.fullname for t in err.cycles)),) + ) + for edge in err.edges: + if edge in mutable_dependencies: + table = edge[1] + if table not in err.cycles: + continue + can_remove = [ + fkc + for fkc in table.foreign_key_constraints + if filter_fn is None or filter_fn(fkc) is not False + ] + remaining_fkcs.update(can_remove) + for fkc in can_remove: + dependent_on = fkc.referred_table + if dependent_on is not table: + mutable_dependencies.discard((dependent_on, table)) + candidate_sort = list( + topological.sort( + fixed_dependencies.union(mutable_dependencies), + tables, + ) + ) + + return [ + (table, table.foreign_key_constraints.difference(remaining_fkcs)) + for table in candidate_sort + ] + [(None, list(remaining_fkcs))] diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py new file mode 100644 index 0000000..70586c6 --- /dev/null +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -0,0 +1,360 @@ +# sql/default_comparator.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Default implementation of SQL comparison operations. +""" + + +from . import coercions +from . import operators +from . import roles +from . import type_api +from .elements import and_ +from .elements import BinaryExpression +from .elements import ClauseList +from .elements import collate +from .elements import CollectionAggregate +from .elements import False_ +from .elements import Null +from .elements import or_ +from .elements import True_ +from .elements import UnaryExpression +from .. import exc +from .. import util + + +def _boolean_compare( + expr, + op, + obj, + negate=None, + reverse=False, + _python_is_types=(util.NoneType, bool), + _any_all_expr=False, + result_type=None, + **kwargs +): + + if result_type is None: + result_type = type_api.BOOLEANTYPE + + if isinstance(obj, _python_is_types + (Null, True_, False_)): + # allow x ==/!= True/False to be treated as a literal. + # this comes out to "== / != true/false" or "1/0" if those + # constants aren't supported and works on all platforms + if op in (operators.eq, operators.ne) and isinstance( + obj, (bool, True_, False_) + ): + return BinaryExpression( + expr, + coercions.expect(roles.ConstExprRole, obj), + op, + type_=result_type, + negate=negate, + modifiers=kwargs, + ) + elif op in ( + operators.is_distinct_from, + operators.is_not_distinct_from, + ): + return BinaryExpression( + expr, + coercions.expect(roles.ConstExprRole, obj), + op, + type_=result_type, + negate=negate, + modifiers=kwargs, + ) + elif _any_all_expr: + obj = coercions.expect( + roles.ConstExprRole, element=obj, operator=op, expr=expr + ) + else: + # all other None uses IS, IS NOT + if op in (operators.eq, operators.is_): + return BinaryExpression( + expr, + coercions.expect(roles.ConstExprRole, obj), + operators.is_, + negate=operators.is_not, + type_=result_type, + ) + elif op in (operators.ne, operators.is_not): + return BinaryExpression( + expr, + coercions.expect(roles.ConstExprRole, obj), + operators.is_not, + negate=operators.is_, + type_=result_type, + ) + else: + raise exc.ArgumentError( + "Only '=', '!=', 'is_()', 'is_not()', " + "'is_distinct_from()', 'is_not_distinct_from()' " + "operators can be used with None/True/False" + ) + else: + obj = coercions.expect( + roles.BinaryElementRole, element=obj, operator=op, expr=expr + ) + + if reverse: + return BinaryExpression( + obj, expr, op, type_=result_type, negate=negate, modifiers=kwargs + ) + else: + return BinaryExpression( + expr, obj, op, type_=result_type, negate=negate, modifiers=kwargs + ) + + +def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw): + if result_type is None: + if op.return_type: + result_type = op.return_type + elif op.is_comparison: + result_type = type_api.BOOLEANTYPE + + return _binary_operate( + expr, op, obj, reverse=reverse, result_type=result_type, **kw + ) + + +def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw): + obj = coercions.expect( + roles.BinaryElementRole, obj, expr=expr, operator=op + ) + + if reverse: + left, right = obj, expr + else: + left, right = expr, obj + + if result_type is None: + op, result_type = left.comparator._adapt_expression( + op, right.comparator + ) + + return BinaryExpression(left, right, op, type_=result_type, modifiers=kw) + + +def _conjunction_operate(expr, op, other, **kw): + if op is operators.and_: + return and_(expr, other) + elif op is operators.or_: + return or_(expr, other) + else: + raise NotImplementedError() + + +def _scalar(expr, op, fn, **kw): + return fn(expr) + + +def _in_impl(expr, op, seq_or_selectable, negate_op, **kw): + seq_or_selectable = coercions.expect( + roles.InElementRole, seq_or_selectable, expr=expr, operator=op + ) + if "in_ops" in seq_or_selectable._annotations: + op, negate_op = seq_or_selectable._annotations["in_ops"] + + return _boolean_compare( + expr, op, seq_or_selectable, negate=negate_op, **kw + ) + + +def _getitem_impl(expr, op, other, **kw): + if ( + isinstance(expr.type, type_api.INDEXABLE) + or isinstance(expr.type, type_api.TypeDecorator) + and isinstance(expr.type.impl, type_api.INDEXABLE) + ): + other = coercions.expect( + roles.BinaryElementRole, other, expr=expr, operator=op + ) + return _binary_operate(expr, op, other, **kw) + else: + _unsupported_impl(expr, op, other, **kw) + + +def _unsupported_impl(expr, op, *arg, **kw): + raise NotImplementedError( + "Operator '%s' is not supported on " "this expression" % op.__name__ + ) + + +def _inv_impl(expr, op, **kw): + """See :meth:`.ColumnOperators.__inv__`.""" + + # undocumented element currently used by the ORM for + # relationship.contains() + if hasattr(expr, "negation_clause"): + return expr.negation_clause + else: + return expr._negate() + + +def _neg_impl(expr, op, **kw): + """See :meth:`.ColumnOperators.__neg__`.""" + return UnaryExpression(expr, operator=operators.neg, type_=expr.type) + + +def _match_impl(expr, op, other, **kw): + """See :meth:`.ColumnOperators.match`.""" + + return _boolean_compare( + expr, + operators.match_op, + coercions.expect( + roles.BinaryElementRole, + other, + expr=expr, + operator=operators.match_op, + ), + result_type=type_api.MATCHTYPE, + negate=operators.not_match_op + if op is operators.match_op + else operators.match_op, + **kw + ) + + +def _distinct_impl(expr, op, **kw): + """See :meth:`.ColumnOperators.distinct`.""" + return UnaryExpression( + expr, operator=operators.distinct_op, type_=expr.type + ) + + +def _between_impl(expr, op, cleft, cright, **kw): + """See :meth:`.ColumnOperators.between`.""" + return BinaryExpression( + expr, + ClauseList( + coercions.expect( + roles.BinaryElementRole, + cleft, + expr=expr, + operator=operators.and_, + ), + coercions.expect( + roles.BinaryElementRole, + cright, + expr=expr, + operator=operators.and_, + ), + operator=operators.and_, + group=False, + group_contents=False, + ), + op, + negate=operators.not_between_op + if op is operators.between_op + else operators.between_op, + modifiers=kw, + ) + + +def _collate_impl(expr, op, other, **kw): + return collate(expr, other) + + +def _regexp_match_impl(expr, op, pattern, flags, **kw): + if flags is not None: + flags = coercions.expect( + roles.BinaryElementRole, + flags, + expr=expr, + operator=operators.regexp_replace_op, + ) + return _boolean_compare( + expr, + op, + pattern, + flags=flags, + negate=operators.not_regexp_match_op + if op is operators.regexp_match_op + else operators.regexp_match_op, + **kw + ) + + +def _regexp_replace_impl(expr, op, pattern, replacement, flags, **kw): + replacement = coercions.expect( + roles.BinaryElementRole, + replacement, + expr=expr, + operator=operators.regexp_replace_op, + ) + if flags is not None: + flags = coercions.expect( + roles.BinaryElementRole, + flags, + expr=expr, + operator=operators.regexp_replace_op, + ) + return _binary_operate( + expr, op, pattern, replacement=replacement, flags=flags, **kw + ) + + +# a mapping of operators with the method they use, along with +# their negated operator for comparison operators +operator_lookup = { + "and_": (_conjunction_operate,), + "or_": (_conjunction_operate,), + "inv": (_inv_impl,), + "add": (_binary_operate,), + "mul": (_binary_operate,), + "sub": (_binary_operate,), + "div": (_binary_operate,), + "mod": (_binary_operate,), + "truediv": (_binary_operate,), + "custom_op": (_custom_op_operate,), + "json_path_getitem_op": (_binary_operate,), + "json_getitem_op": (_binary_operate,), + "concat_op": (_binary_operate,), + "any_op": (_scalar, CollectionAggregate._create_any), + "all_op": (_scalar, CollectionAggregate._create_all), + "lt": (_boolean_compare, operators.ge), + "le": (_boolean_compare, operators.gt), + "ne": (_boolean_compare, operators.eq), + "gt": (_boolean_compare, operators.le), + "ge": (_boolean_compare, operators.lt), + "eq": (_boolean_compare, operators.ne), + "is_distinct_from": (_boolean_compare, operators.is_not_distinct_from), + "is_not_distinct_from": (_boolean_compare, operators.is_distinct_from), + "like_op": (_boolean_compare, operators.not_like_op), + "ilike_op": (_boolean_compare, operators.not_ilike_op), + "not_like_op": (_boolean_compare, operators.like_op), + "not_ilike_op": (_boolean_compare, operators.ilike_op), + "contains_op": (_boolean_compare, operators.not_contains_op), + "startswith_op": (_boolean_compare, operators.not_startswith_op), + "endswith_op": (_boolean_compare, operators.not_endswith_op), + "desc_op": (_scalar, UnaryExpression._create_desc), + "asc_op": (_scalar, UnaryExpression._create_asc), + "nulls_first_op": (_scalar, UnaryExpression._create_nulls_first), + "nulls_last_op": (_scalar, UnaryExpression._create_nulls_last), + "in_op": (_in_impl, operators.not_in_op), + "not_in_op": (_in_impl, operators.in_op), + "is_": (_boolean_compare, operators.is_), + "is_not": (_boolean_compare, operators.is_not), + "collate": (_collate_impl,), + "match_op": (_match_impl,), + "not_match_op": (_match_impl,), + "distinct_op": (_distinct_impl,), + "between_op": (_between_impl,), + "not_between_op": (_between_impl,), + "neg": (_neg_impl,), + "getitem": (_getitem_impl,), + "lshift": (_unsupported_impl,), + "rshift": (_unsupported_impl,), + "contains": (_unsupported_impl,), + "regexp_match_op": (_regexp_match_impl,), + "not_regexp_match_op": (_regexp_match_impl,), + "regexp_replace_op": (_regexp_replace_impl,), +} diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py new file mode 100644 index 0000000..07a4d7b --- /dev/null +++ b/lib/sqlalchemy/sql/dml.py @@ -0,0 +1,1514 @@ +# sql/dml.py +# Copyright (C) 2009-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +""" +Provide :class:`_expression.Insert`, :class:`_expression.Update` and +:class:`_expression.Delete`. + +""" +from sqlalchemy.types import NullType +from . import coercions +from . import roles +from . import util as sql_util +from .base import _entity_namespace_key +from .base import _exclusive_against +from .base import _from_objects +from .base import _generative +from .base import ColumnCollection +from .base import CompileState +from .base import DialectKWArgs +from .base import Executable +from .base import HasCompileState +from .elements import BooleanClauseList +from .elements import ClauseElement +from .elements import Null +from .selectable import HasCTE +from .selectable import HasPrefixes +from .selectable import ReturnsRows +from .visitors import InternalTraversal +from .. import exc +from .. import util +from ..util import collections_abc + + +class DMLState(CompileState): + _no_parameters = True + _dict_parameters = None + _multi_parameters = None + _ordered_values = None + _parameter_ordering = None + _has_multi_parameters = False + isupdate = False + isdelete = False + isinsert = False + + def __init__(self, statement, compiler, **kw): + raise NotImplementedError() + + @classmethod + def get_entity_description(cls, statement): + return {"name": statement.table.name, "table": statement.table} + + @classmethod + def get_returning_column_descriptions(cls, statement): + return [ + { + "name": c.key, + "type": c.type, + "expr": c, + } + for c in statement._all_selected_columns + ] + + @property + def dml_table(self): + return self.statement.table + + @classmethod + def _get_crud_kv_pairs(cls, statement, kv_iterator): + return [ + ( + coercions.expect(roles.DMLColumnRole, k), + coercions.expect( + roles.ExpressionElementRole, + v, + type_=NullType(), + is_crud=True, + ), + ) + for k, v in kv_iterator + ] + + def _make_extra_froms(self, statement): + froms = [] + + all_tables = list(sql_util.tables_from_leftmost(statement.table)) + seen = {all_tables[0]} + + for crit in statement._where_criteria: + for item in _from_objects(crit): + if not seen.intersection(item._cloned_set): + froms.append(item) + seen.update(item._cloned_set) + + froms.extend(all_tables[1:]) + return froms + + def _process_multi_values(self, statement): + if not statement._supports_multi_parameters: + raise exc.InvalidRequestError( + "%s construct does not support " + "multiple parameter sets." % statement.__visit_name__.upper() + ) + + for parameters in statement._multi_values: + multi_parameters = [ + { + c.key: value + for c, value in zip(statement.table.c, parameter_set) + } + if isinstance(parameter_set, collections_abc.Sequence) + else parameter_set + for parameter_set in parameters + ] + + if self._no_parameters: + self._no_parameters = False + self._has_multi_parameters = True + self._multi_parameters = multi_parameters + self._dict_parameters = self._multi_parameters[0] + elif not self._has_multi_parameters: + self._cant_mix_formats_error() + else: + self._multi_parameters.extend(multi_parameters) + + def _process_values(self, statement): + if self._no_parameters: + self._has_multi_parameters = False + self._dict_parameters = statement._values + self._no_parameters = False + elif self._has_multi_parameters: + self._cant_mix_formats_error() + + def _process_ordered_values(self, statement): + parameters = statement._ordered_values + + if self._no_parameters: + self._no_parameters = False + self._dict_parameters = dict(parameters) + self._ordered_values = parameters + self._parameter_ordering = [key for key, value in parameters] + elif self._has_multi_parameters: + self._cant_mix_formats_error() + else: + raise exc.InvalidRequestError( + "Can only invoke ordered_values() once, and not mixed " + "with any other values() call" + ) + + def _process_select_values(self, statement): + parameters = { + coercions.expect(roles.DMLColumnRole, name, as_key=True): Null() + for name in statement._select_names + } + + if self._no_parameters: + self._no_parameters = False + self._dict_parameters = parameters + else: + # this condition normally not reachable as the Insert + # does not allow this construction to occur + assert False, "This statement already has parameters" + + def _cant_mix_formats_error(self): + raise exc.InvalidRequestError( + "Can't mix single and multiple VALUES " + "formats in one INSERT statement; one style appends to a " + "list while the other replaces values, so the intent is " + "ambiguous." + ) + + +@CompileState.plugin_for("default", "insert") +class InsertDMLState(DMLState): + isinsert = True + + include_table_with_column_exprs = False + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + self.isinsert = True + if statement._select_names: + self._process_select_values(statement) + if statement._values is not None: + self._process_values(statement) + if statement._multi_values: + self._process_multi_values(statement) + + @util.memoized_property + def _insert_col_keys(self): + # this is also done in crud.py -> _key_getters_for_crud_column + return [ + coercions.expect_as_key(roles.DMLColumnRole, col) + for col in self._dict_parameters + ] + + +@CompileState.plugin_for("default", "update") +class UpdateDMLState(DMLState): + isupdate = True + + include_table_with_column_exprs = False + + def __init__(self, statement, compiler, **kw): + self.statement = statement + self.isupdate = True + self._preserve_parameter_order = statement._preserve_parameter_order + if statement._ordered_values is not None: + self._process_ordered_values(statement) + elif statement._values is not None: + self._process_values(statement) + elif statement._multi_values: + self._process_multi_values(statement) + self._extra_froms = ef = self._make_extra_froms(statement) + self.is_multitable = mt = ef and self._dict_parameters + self.include_table_with_column_exprs = ( + mt and compiler.render_table_with_column_in_update_from + ) + + +@CompileState.plugin_for("default", "delete") +class DeleteDMLState(DMLState): + isdelete = True + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + self.isdelete = True + self._extra_froms = self._make_extra_froms(statement) + + +class UpdateBase( + roles.DMLRole, + HasCTE, + HasCompileState, + DialectKWArgs, + HasPrefixes, + ReturnsRows, + Executable, + ClauseElement, +): + """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" + + __visit_name__ = "update_base" + + _execution_options = Executable._execution_options.union( + {"autocommit": True} + ) + _hints = util.immutabledict() + named_with_column = False + + _return_defaults = False + _return_defaults_columns = None + _returning = () + + is_dml = True + + @classmethod + def _constructor_20_deprecations(cls, fn_name, clsname, names): + + param_to_method_lookup = dict( + whereclause=( + "The :paramref:`%(func)s.whereclause` parameter " + "will be removed " + "in SQLAlchemy 2.0. Please refer to the " + ":meth:`%(classname)s.where` method." + ), + values=( + "The :paramref:`%(func)s.values` parameter will be removed " + "in SQLAlchemy 2.0. Please refer to the " + ":meth:`%(classname)s.values` method." + ), + bind=( + "The :paramref:`%(func)s.bind` parameter will be removed in " + "SQLAlchemy 2.0. Please use explicit connection execution." + ), + inline=( + "The :paramref:`%(func)s.inline` parameter will be " + "removed in " + "SQLAlchemy 2.0. Please use the " + ":meth:`%(classname)s.inline` method." + ), + prefixes=( + "The :paramref:`%(func)s.prefixes parameter will be " + "removed in " + "SQLAlchemy 2.0. Please use the " + ":meth:`%(classname)s.prefix_with` " + "method." + ), + return_defaults=( + "The :paramref:`%(func)s.return_defaults` parameter will be " + "removed in SQLAlchemy 2.0. Please use the " + ":meth:`%(classname)s.return_defaults` method." + ), + returning=( + "The :paramref:`%(func)s.returning` parameter will be " + "removed in SQLAlchemy 2.0. Please use the " + ":meth:`%(classname)s.returning`` method." + ), + preserve_parameter_order=( + "The :paramref:`%(func)s.preserve_parameter_order` parameter " + "will be removed in SQLAlchemy 2.0. Use the " + ":meth:`%(classname)s.ordered_values` method with a list " + "of tuples. " + ), + ) + + return util.deprecated_params( + **{ + name: ( + "2.0", + param_to_method_lookup[name] + % { + "func": "_expression.%s" % fn_name, + "classname": "_expression.%s" % clsname, + }, + ) + for name in names + } + ) + + def _generate_fromclause_column_proxies(self, fromclause): + fromclause._columns._populate_separate_keys( + col._make_proxy(fromclause) for col in self._returning + ) + + def params(self, *arg, **kw): + """Set the parameters for the statement. + + This method raises ``NotImplementedError`` on the base class, + and is overridden by :class:`.ValuesBase` to provide the + SET/VALUES clause of UPDATE and INSERT. + + """ + raise NotImplementedError( + "params() is not supported for INSERT/UPDATE/DELETE statements." + " To set the values for an INSERT or UPDATE statement, use" + " stmt.values(**parameters)." + ) + + @_generative + def with_dialect_options(self, **opt): + """Add dialect options to this INSERT/UPDATE/DELETE object. + + e.g.:: + + upd = table.update().dialect_options(mysql_limit=10) + + .. versionadded: 1.4 - this method supersedes the dialect options + associated with the constructor. + + + """ + self._validate_dialect_kwargs(opt) + + def _validate_dialect_kwargs_deprecated(self, dialect_kw): + util.warn_deprecated_20( + "Passing dialect keyword arguments directly to the " + "%s constructor is deprecated and will be removed in SQLAlchemy " + "2.0. Please use the ``with_dialect_options()`` method." + % (self.__class__.__name__) + ) + self._validate_dialect_kwargs(dialect_kw) + + def bind(self): + """Return a 'bind' linked to this :class:`.UpdateBase` + or a :class:`_schema.Table` associated with it. + + """ + return self._bind or self.table.bind + + def _set_bind(self, bind): + self._bind = bind + + bind = property(bind, _set_bind) + + @_generative + def returning(self, *cols): + r"""Add a :term:`RETURNING` or equivalent clause to this statement. + + e.g.: + + .. sourcecode:: pycon+sql + + >>> stmt = ( + ... table.update() + ... .where(table.c.data == "value") + ... .values(status="X") + ... .returning(table.c.server_flag, table.c.updated_timestamp) + ... ) + >>> print(stmt) + UPDATE some_table SET status=:status + WHERE some_table.data = :data_1 + RETURNING some_table.server_flag, some_table.updated_timestamp + + The method may be invoked multiple times to add new entries to the + list of expressions to be returned. + + .. versionadded:: 1.4.0b2 The method may be invoked multiple times to + add new entries to the list of expressions to be returned. + + The given collection of column expressions should be derived from the + table that is the target of the INSERT, UPDATE, or DELETE. While + :class:`_schema.Column` objects are typical, the elements can also be + expressions: + + .. sourcecode:: pycon+sql + + >>> stmt = table.insert().returning( + ... (table.c.first_name + " " + table.c.last_name).label("fullname") + ... ) + >>> print(stmt) + INSERT INTO some_table (first_name, last_name) + VALUES (:first_name, :last_name) + RETURNING some_table.first_name || :first_name_1 || some_table.last_name AS fullname + + Upon compilation, a RETURNING clause, or database equivalent, + will be rendered within the statement. For INSERT and UPDATE, + the values are the newly inserted/updated values. For DELETE, + the values are those of the rows which were deleted. + + Upon execution, the values of the columns to be returned are made + available via the result set and can be iterated using + :meth:`_engine.CursorResult.fetchone` and similar. + For DBAPIs which do not + natively support returning values (i.e. cx_oracle), SQLAlchemy will + approximate this behavior at the result level so that a reasonable + amount of behavioral neutrality is provided. + + Note that not all databases/DBAPIs + support RETURNING. For those backends with no support, + an exception is raised upon compilation and/or execution. + For those who do support it, the functionality across backends + varies greatly, including restrictions on executemany() + and other statements which return multiple rows. Please + read the documentation notes for the database in use in + order to determine the availability of RETURNING. + + .. seealso:: + + :meth:`.ValuesBase.return_defaults` - an alternative method tailored + towards efficient fetching of server-side defaults and triggers + for single-row INSERTs or UPDATEs. + + :ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial` + + """ # noqa: E501 + if self._return_defaults: + raise exc.InvalidRequestError( + "return_defaults() is already configured on this statement" + ) + self._returning += tuple( + coercions.expect(roles.ColumnsClauseRole, c) for c in cols + ) + + @property + def _all_selected_columns(self): + return self._returning + + @property + def exported_columns(self): + """Return the RETURNING columns as a column collection for this + statement. + + .. versionadded:: 1.4 + + """ + # TODO: no coverage here + return ColumnCollection( + (c.key, c) for c in self._all_selected_columns + ).as_immutable() + + @_generative + def with_hint(self, text, selectable=None, dialect_name="*"): + """Add a table hint for a single table to this + INSERT/UPDATE/DELETE statement. + + .. note:: + + :meth:`.UpdateBase.with_hint` currently applies only to + Microsoft SQL Server. For MySQL INSERT/UPDATE/DELETE hints, use + :meth:`.UpdateBase.prefix_with`. + + The text of the hint is rendered in the appropriate + location for the database backend in use, relative + to the :class:`_schema.Table` that is the subject of this + statement, or optionally to that of the given + :class:`_schema.Table` passed as the ``selectable`` argument. + + The ``dialect_name`` option will limit the rendering of a particular + hint to a particular backend. Such as, to add a hint + that only takes effect for SQL Server:: + + mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql") + + :param text: Text of the hint. + :param selectable: optional :class:`_schema.Table` that specifies + an element of the FROM clause within an UPDATE or DELETE + to be the subject of the hint - applies only to certain backends. + :param dialect_name: defaults to ``*``, if specified as the name + of a particular dialect, will apply these hints only when + that dialect is in use. + """ + if selectable is None: + selectable = self.table + + self._hints = self._hints.union({(selectable, dialect_name): text}) + + @property + def entity_description(self): + """Return a :term:`plugin-enabled` description of the table and/or + entity which this DML construct is operating against. + + This attribute is generally useful when using the ORM, as an + extended structure which includes information about mapped + entities is returned. The section :ref:`queryguide_inspection` + contains more background. + + For a Core statement, the structure returned by this accessor + is derived from the :attr:`.UpdateBase.table` attribute, and + refers to the :class:`.Table` being inserted, updated, or deleted:: + + >>> stmt = insert(user_table) + >>> stmt.entity_description + { + "name": "user_table", + "table": Table("user_table", ...) + } + + .. versionadded:: 1.4.33 + + .. seealso:: + + :attr:`.UpdateBase.returning_column_descriptions` + + :attr:`.Select.column_descriptions` - entity information for + a :func:`.select` construct + + :ref:`queryguide_inspection` - ORM background + + """ + meth = DMLState.get_plugin_class(self).get_entity_description + return meth(self) + + @property + def returning_column_descriptions(self): + """Return a :term:`plugin-enabled` description of the columns + which this DML construct is RETURNING against, in other words + the expressions established as part of :meth:`.UpdateBase.returning`. + + This attribute is generally useful when using the ORM, as an + extended structure which includes information about mapped + entities is returned. The section :ref:`queryguide_inspection` + contains more background. + + For a Core statement, the structure returned by this accessor is + derived from the same objects that are returned by the + :attr:`.UpdateBase.exported_columns` accessor:: + + >>> stmt = insert(user_table).returning(user_table.c.id, user_table.c.name) + >>> stmt.entity_description + [ + { + "name": "id", + "type": Integer, + "expr": Column("id", Integer(), table=, ...) + }, + { + "name": "name", + "type": String(), + "expr": Column("name", String(), table=, ...) + }, + ] + + .. versionadded:: 1.4.33 + + .. seealso:: + + :attr:`.UpdateBase.entity_description` + + :attr:`.Select.column_descriptions` - entity information for + a :func:`.select` construct + + :ref:`queryguide_inspection` - ORM background + + """ # noqa: E501 + meth = DMLState.get_plugin_class( + self + ).get_returning_column_descriptions + return meth(self) + + +class ValuesBase(UpdateBase): + """Supplies support for :meth:`.ValuesBase.values` to + INSERT and UPDATE constructs.""" + + __visit_name__ = "values_base" + + _supports_multi_parameters = False + _preserve_parameter_order = False + select = None + _post_values_clause = None + + _values = None + _multi_values = () + _ordered_values = None + _select_names = None + + _returning = () + + def __init__(self, table, values, prefixes): + self.table = coercions.expect( + roles.DMLTableRole, table, apply_propagate_attrs=self + ) + if values is not None: + self.values.non_generative(self, values) + if prefixes: + self._setup_prefixes(prefixes) + + @_generative + @_exclusive_against( + "_select_names", + "_ordered_values", + msgs={ + "_select_names": "This construct already inserts from a SELECT", + "_ordered_values": "This statement already has ordered " + "values present", + }, + ) + def values(self, *args, **kwargs): + r"""Specify a fixed VALUES clause for an INSERT statement, or the SET + clause for an UPDATE. + + Note that the :class:`_expression.Insert` and + :class:`_expression.Update` + constructs support + per-execution time formatting of the VALUES and/or SET clauses, + based on the arguments passed to :meth:`_engine.Connection.execute`. + However, the :meth:`.ValuesBase.values` method can be used to "fix" a + particular set of parameters into the statement. + + Multiple calls to :meth:`.ValuesBase.values` will produce a new + construct, each one with the parameter list modified to include + the new parameters sent. In the typical case of a single + dictionary of parameters, the newly passed keys will replace + the same keys in the previous construct. In the case of a list-based + "multiple values" construct, each new list of values is extended + onto the existing list of values. + + :param \**kwargs: key value pairs representing the string key + of a :class:`_schema.Column` + mapped to the value to be rendered into the + VALUES or SET clause:: + + users.insert().values(name="some name") + + users.update().where(users.c.id==5).values(name="some name") + + :param \*args: As an alternative to passing key/value parameters, + a dictionary, tuple, or list of dictionaries or tuples can be passed + as a single positional argument in order to form the VALUES or + SET clause of the statement. The forms that are accepted vary + based on whether this is an :class:`_expression.Insert` or an + :class:`_expression.Update` construct. + + For either an :class:`_expression.Insert` or + :class:`_expression.Update` + construct, a single dictionary can be passed, which works the same as + that of the kwargs form:: + + users.insert().values({"name": "some name"}) + + users.update().values({"name": "some new name"}) + + Also for either form but more typically for the + :class:`_expression.Insert` construct, a tuple that contains an + entry for every column in the table is also accepted:: + + users.insert().values((5, "some name")) + + The :class:`_expression.Insert` construct also supports being + passed a list of dictionaries or full-table-tuples, which on the + server will render the less common SQL syntax of "multiple values" - + this syntax is supported on backends such as SQLite, PostgreSQL, + MySQL, but not necessarily others:: + + users.insert().values([ + {"name": "some name"}, + {"name": "some other name"}, + {"name": "yet another name"}, + ]) + + The above form would render a multiple VALUES statement similar to:: + + INSERT INTO users (name) VALUES + (:name_1), + (:name_2), + (:name_3) + + It is essential to note that **passing multiple values is + NOT the same as using traditional executemany() form**. The above + syntax is a **special** syntax not typically used. To emit an + INSERT statement against multiple rows, the normal method is + to pass a multiple values list to the + :meth:`_engine.Connection.execute` + method, which is supported by all database backends and is generally + more efficient for a very large number of parameters. + + .. seealso:: + + :ref:`tutorial_multiple_parameters` - an introduction to + the traditional Core method of multiple parameter set + invocation for INSERTs and other statements. + + .. versionchanged:: 1.0.0 an INSERT that uses a multiple-VALUES + clause, even a list of length one, + implies that the :paramref:`_expression.Insert.inline` + flag is set to + True, indicating that the statement will not attempt to fetch + the "last inserted primary key" or other defaults. The + statement deals with an arbitrary number of rows, so the + :attr:`_engine.CursorResult.inserted_primary_key` + accessor does not + apply. + + .. versionchanged:: 1.0.0 A multiple-VALUES INSERT now supports + columns with Python side default values and callables in the + same way as that of an "executemany" style of invocation; the + callable is invoked for each row. See :ref:`bug_3288` + for other details. + + The UPDATE construct also supports rendering the SET parameters + in a specific order. For this feature refer to the + :meth:`_expression.Update.ordered_values` method. + + .. seealso:: + + :meth:`_expression.Update.ordered_values` + + + """ + if args: + # positional case. this is currently expensive. we don't + # yet have positional-only args so we have to check the length. + # then we need to check multiparams vs. single dictionary. + # since the parameter format is needed in order to determine + # a cache key, we need to determine this up front. + arg = args[0] + + if kwargs: + raise exc.ArgumentError( + "Can't pass positional and kwargs to values() " + "simultaneously" + ) + elif len(args) > 1: + raise exc.ArgumentError( + "Only a single dictionary/tuple or list of " + "dictionaries/tuples is accepted positionally." + ) + + elif not self._preserve_parameter_order and isinstance( + arg, collections_abc.Sequence + ): + + if arg and isinstance(arg[0], (list, dict, tuple)): + self._multi_values += (arg,) + return + + # tuple values + arg = {c.key: value for c, value in zip(self.table.c, arg)} + elif self._preserve_parameter_order and not isinstance( + arg, collections_abc.Sequence + ): + raise ValueError( + "When preserve_parameter_order is True, " + "values() only accepts a list of 2-tuples" + ) + + else: + # kwarg path. this is the most common path for non-multi-params + # so this is fairly quick. + arg = kwargs + if args: + raise exc.ArgumentError( + "Only a single dictionary/tuple or list of " + "dictionaries/tuples is accepted positionally." + ) + + # for top level values(), convert literals to anonymous bound + # parameters at statement construction time, so that these values can + # participate in the cache key process like any other ClauseElement. + # crud.py now intercepts bound parameters with unique=True from here + # and ensures they get the "crud"-style name when rendered. + + kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs + + if self._preserve_parameter_order: + self._ordered_values = kv_generator(self, arg) + else: + arg = {k: v for k, v in kv_generator(self, arg.items())} + if self._values: + self._values = self._values.union(arg) + else: + self._values = util.immutabledict(arg) + + @_generative + @_exclusive_against( + "_returning", + msgs={ + "_returning": "RETURNING is already configured on this statement" + }, + defaults={"_returning": _returning}, + ) + def return_defaults(self, *cols): + """Make use of a :term:`RETURNING` clause for the purpose + of fetching server-side expressions and defaults. + + E.g.:: + + stmt = table.insert().values(data='newdata').return_defaults() + + result = connection.execute(stmt) + + server_created_at = result.returned_defaults['created_at'] + + When used against a backend that supports RETURNING, all column + values generated by SQL expression or server-side-default will be + added to any existing RETURNING clause, provided that + :meth:`.UpdateBase.returning` is not used simultaneously. The column + values will then be available on the result using the + :attr:`_engine.CursorResult.returned_defaults` accessor as + a dictionary, + referring to values keyed to the :class:`_schema.Column` + object as well as + its ``.key``. + + This method differs from :meth:`.UpdateBase.returning` in these ways: + + 1. :meth:`.ValuesBase.return_defaults` is only intended for use with an + INSERT or an UPDATE statement that matches exactly one row per + parameter set. While the RETURNING construct in the general sense + supports multiple rows for a multi-row UPDATE or DELETE statement, + or for special cases of INSERT that return multiple rows (e.g. + INSERT from SELECT, multi-valued VALUES clause), + :meth:`.ValuesBase.return_defaults` is intended only for an + "ORM-style" single-row INSERT/UPDATE statement. The row + returned by the statement is also consumed implicitly when + :meth:`.ValuesBase.return_defaults` is used. By contrast, + :meth:`.UpdateBase.returning` leaves the RETURNING result-set intact + with a collection of any number of rows. + + 2. It is compatible with the existing logic to fetch auto-generated + primary key values, also known as "implicit returning". Backends + that support RETURNING will automatically make use of RETURNING in + order to fetch the value of newly generated primary keys; while the + :meth:`.UpdateBase.returning` method circumvents this behavior, + :meth:`.ValuesBase.return_defaults` leaves it intact. + + 3. It can be called against any backend. Backends that don't support + RETURNING will skip the usage of the feature, rather than raising + an exception. The return value of + :attr:`_engine.CursorResult.returned_defaults` will be ``None`` + + 4. An INSERT statement invoked with executemany() is supported if the + backend database driver supports the + ``insert_executemany_returning`` feature, currently this includes + PostgreSQL with psycopg2. When executemany is used, the + :attr:`_engine.CursorResult.returned_defaults_rows` and + :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors + will return the inserted defaults and primary keys. + + .. versionadded:: 1.4 + + :meth:`.ValuesBase.return_defaults` is used by the ORM to provide + an efficient implementation for the ``eager_defaults`` feature of + :func:`.mapper`. + + :param cols: optional list of column key names or + :class:`_schema.Column` + objects. If omitted, all column expressions evaluated on the server + are added to the returning list. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :meth:`.UpdateBase.returning` + + :attr:`_engine.CursorResult.returned_defaults` + + :attr:`_engine.CursorResult.returned_defaults_rows` + + :attr:`_engine.CursorResult.inserted_primary_key` + + :attr:`_engine.CursorResult.inserted_primary_key_rows` + + """ + self._return_defaults = True + self._return_defaults_columns = cols + + +class Insert(ValuesBase): + """Represent an INSERT construct. + + The :class:`_expression.Insert` object is created using the + :func:`_expression.insert()` function. + + """ + + __visit_name__ = "insert" + + _supports_multi_parameters = True + + select = None + include_insert_from_select_defaults = False + + is_insert = True + + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_inline", InternalTraversal.dp_boolean), + ("_select_names", InternalTraversal.dp_string_list), + ("_values", InternalTraversal.dp_dml_values), + ("_multi_values", InternalTraversal.dp_dml_multi_values), + ("select", InternalTraversal.dp_clauseelement), + ("_post_values_clause", InternalTraversal.dp_clauseelement), + ("_returning", InternalTraversal.dp_clauseelement_list), + ("_hints", InternalTraversal.dp_table_hint_list), + ("_return_defaults", InternalTraversal.dp_boolean), + ( + "_return_defaults_columns", + InternalTraversal.dp_clauseelement_list, + ), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals + ) + + @ValuesBase._constructor_20_deprecations( + "insert", + "Insert", + [ + "values", + "inline", + "bind", + "prefixes", + "returning", + "return_defaults", + ], + ) + def __init__( + self, + table, + values=None, + inline=False, + bind=None, + prefixes=None, + returning=None, + return_defaults=False, + **dialect_kw + ): + """Construct an :class:`_expression.Insert` object. + + E.g.:: + + from sqlalchemy import insert + + stmt = ( + insert(user_table). + values(name='username', fullname='Full Username') + ) + + Similar functionality is available via the + :meth:`_expression.TableClause.insert` method on + :class:`_schema.Table`. + + .. seealso:: + + :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial` + + + :param table: :class:`_expression.TableClause` + which is the subject of the + insert. + + :param values: collection of values to be inserted; see + :meth:`_expression.Insert.values` + for a description of allowed formats here. + Can be omitted entirely; a :class:`_expression.Insert` construct + will also dynamically render the VALUES clause at execution time + based on the parameters passed to :meth:`_engine.Connection.execute`. + + :param inline: if True, no attempt will be made to retrieve the + SQL-generated default values to be provided within the statement; + in particular, + this allows SQL expressions to be rendered 'inline' within the + statement without the need to pre-execute them beforehand; for + backends that support "returning", this turns off the "implicit + returning" feature for the statement. + + If both :paramref:`_expression.Insert.values` and compile-time bind + parameters are present, the compile-time bind parameters override the + information specified within :paramref:`_expression.Insert.values` on a + per-key basis. + + The keys within :paramref:`_expression.Insert.values` can be either + :class:`~sqlalchemy.schema.Column` objects or their string + identifiers. Each key may reference one of: + + * a literal data value (i.e. string, number, etc.); + * a Column object; + * a SELECT statement. + + If a ``SELECT`` statement is specified which references this + ``INSERT`` statement's table, the statement will be correlated + against the ``INSERT`` statement. + + .. seealso:: + + :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial` + + """ + super(Insert, self).__init__(table, values, prefixes) + self._bind = bind + self._inline = inline + if returning: + self._returning = returning + if dialect_kw: + self._validate_dialect_kwargs_deprecated(dialect_kw) + + if return_defaults: + self._return_defaults = True + if not isinstance(return_defaults, bool): + self._return_defaults_columns = return_defaults + + @_generative + def inline(self): + """Make this :class:`_expression.Insert` construct "inline" . + + When set, no attempt will be made to retrieve the + SQL-generated default values to be provided within the statement; + in particular, + this allows SQL expressions to be rendered 'inline' within the + statement without the need to pre-execute them beforehand; for + backends that support "returning", this turns off the "implicit + returning" feature for the statement. + + + .. versionchanged:: 1.4 the :paramref:`_expression.Insert.inline` + parameter + is now superseded by the :meth:`_expression.Insert.inline` method. + + """ + self._inline = True + + @_generative + def from_select(self, names, select, include_defaults=True): + """Return a new :class:`_expression.Insert` construct which represents + an ``INSERT...FROM SELECT`` statement. + + e.g.:: + + sel = select(table1.c.a, table1.c.b).where(table1.c.c > 5) + ins = table2.insert().from_select(['a', 'b'], sel) + + :param names: a sequence of string column names or + :class:`_schema.Column` + objects representing the target columns. + :param select: a :func:`_expression.select` construct, + :class:`_expression.FromClause` + or other construct which resolves into a + :class:`_expression.FromClause`, + such as an ORM :class:`_query.Query` object, etc. The order of + columns returned from this FROM clause should correspond to the + order of columns sent as the ``names`` parameter; while this + is not checked before passing along to the database, the database + would normally raise an exception if these column lists don't + correspond. + :param include_defaults: if True, non-server default values and + SQL expressions as specified on :class:`_schema.Column` objects + (as documented in :ref:`metadata_defaults_toplevel`) not + otherwise specified in the list of names will be rendered + into the INSERT and SELECT statements, so that these values are also + included in the data to be inserted. + + .. note:: A Python-side default that uses a Python callable function + will only be invoked **once** for the whole statement, and **not + per row**. + + .. versionadded:: 1.0.0 - :meth:`_expression.Insert.from_select` + now renders + Python-side and SQL expression column defaults into the + SELECT statement for columns otherwise not included in the + list of column names. + + .. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT + implies that the :paramref:`_expression.insert.inline` + flag is set to + True, indicating that the statement will not attempt to fetch + the "last inserted primary key" or other defaults. The statement + deals with an arbitrary number of rows, so the + :attr:`_engine.CursorResult.inserted_primary_key` + accessor does not apply. + + """ + + if self._values: + raise exc.InvalidRequestError( + "This construct already inserts value expressions" + ) + + self._select_names = names + self._inline = True + self.include_insert_from_select_defaults = include_defaults + self.select = coercions.expect(roles.DMLSelectRole, select) + + +class DMLWhereBase(object): + _where_criteria = () + + @_generative + def where(self, *whereclause): + """Return a new construct with the given expression(s) added to + its WHERE clause, joined to the existing clause via AND, if any. + + Both :meth:`_dml.Update.where` and :meth:`_dml.Delete.where` + support multiple-table forms, including database-specific + ``UPDATE...FROM`` as well as ``DELETE..USING``. For backends that + don't have multiple-table support, a backend agnostic approach + to using multiple tables is to make use of correlated subqueries. + See the linked tutorial sections below for examples. + + .. seealso:: + + :ref:`tutorial_correlated_updates` + + :ref:`tutorial_update_from` + + :ref:`tutorial_multi_table_deletes` + + """ + + for criterion in whereclause: + where_criteria = coercions.expect(roles.WhereHavingRole, criterion) + self._where_criteria += (where_criteria,) + + def filter(self, *criteria): + """A synonym for the :meth:`_dml.DMLWhereBase.where` method. + + .. versionadded:: 1.4 + + """ + + return self.where(*criteria) + + def _filter_by_zero(self): + return self.table + + def filter_by(self, **kwargs): + r"""apply the given filtering criterion as a WHERE clause + to this select. + + """ + from_entity = self._filter_by_zero() + + clauses = [ + _entity_namespace_key(from_entity, key) == value + for key, value in kwargs.items() + ] + return self.filter(*clauses) + + @property + def whereclause(self): + """Return the completed WHERE clause for this :class:`.DMLWhereBase` + statement. + + This assembles the current collection of WHERE criteria + into a single :class:`_expression.BooleanClauseList` construct. + + + .. versionadded:: 1.4 + + """ + + return BooleanClauseList._construct_for_whereclause( + self._where_criteria + ) + + +class Update(DMLWhereBase, ValuesBase): + """Represent an Update construct. + + The :class:`_expression.Update` object is created using the + :func:`_expression.update()` function. + + """ + + __visit_name__ = "update" + + is_update = True + + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_where_criteria", InternalTraversal.dp_clauseelement_list), + ("_inline", InternalTraversal.dp_boolean), + ("_ordered_values", InternalTraversal.dp_dml_ordered_values), + ("_values", InternalTraversal.dp_dml_values), + ("_returning", InternalTraversal.dp_clauseelement_list), + ("_hints", InternalTraversal.dp_table_hint_list), + ("_return_defaults", InternalTraversal.dp_boolean), + ( + "_return_defaults_columns", + InternalTraversal.dp_clauseelement_list, + ), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals + ) + + @ValuesBase._constructor_20_deprecations( + "update", + "Update", + [ + "whereclause", + "values", + "inline", + "bind", + "prefixes", + "returning", + "return_defaults", + "preserve_parameter_order", + ], + ) + def __init__( + self, + table, + whereclause=None, + values=None, + inline=False, + bind=None, + prefixes=None, + returning=None, + return_defaults=False, + preserve_parameter_order=False, + **dialect_kw + ): + r"""Construct an :class:`_expression.Update` object. + + E.g.:: + + from sqlalchemy import update + + stmt = ( + update(user_table). + where(user_table.c.id == 5). + values(name='user #5') + ) + + Similar functionality is available via the + :meth:`_expression.TableClause.update` method on + :class:`_schema.Table`. + + :param table: A :class:`_schema.Table` + object representing the database + table to be updated. + + :param whereclause: Optional SQL expression describing the ``WHERE`` + condition of the ``UPDATE`` statement; is equivalent to using the + more modern :meth:`~Update.where()` method to specify the ``WHERE`` + clause. + + :param values: + Optional dictionary which specifies the ``SET`` conditions of the + ``UPDATE``. If left as ``None``, the ``SET`` + conditions are determined from those parameters passed to the + statement during the execution and/or compilation of the + statement. When compiled standalone without any parameters, + the ``SET`` clause generates for all columns. + + Modern applications may prefer to use the generative + :meth:`_expression.Update.values` method to set the values of the + UPDATE statement. + + :param inline: + if True, SQL defaults present on :class:`_schema.Column` objects via + the ``default`` keyword will be compiled 'inline' into the statement + and not pre-executed. This means that their values will not + be available in the dictionary returned from + :meth:`_engine.CursorResult.last_updated_params`. + + :param preserve_parameter_order: if True, the update statement is + expected to receive parameters **only** via the + :meth:`_expression.Update.values` method, + and they must be passed as a Python + ``list`` of 2-tuples. The rendered UPDATE statement will emit the SET + clause for each referenced column maintaining this order. + + .. versionadded:: 1.0.10 + + .. seealso:: + + :ref:`updates_order_parameters` - illustrates the + :meth:`_expression.Update.ordered_values` method. + + If both ``values`` and compile-time bind parameters are present, the + compile-time bind parameters override the information specified + within ``values`` on a per-key basis. + + The keys within ``values`` can be either :class:`_schema.Column` + objects or their string identifiers (specifically the "key" of the + :class:`_schema.Column`, normally but not necessarily equivalent to + its "name"). Normally, the + :class:`_schema.Column` objects used here are expected to be + part of the target :class:`_schema.Table` that is the table + to be updated. However when using MySQL, a multiple-table + UPDATE statement can refer to columns from any of + the tables referred to in the WHERE clause. + + The values referred to in ``values`` are typically: + + * a literal data value (i.e. string, number, etc.) + * a SQL expression, such as a related :class:`_schema.Column`, + a scalar-returning :func:`_expression.select` construct, + etc. + + When combining :func:`_expression.select` constructs within the + values clause of an :func:`_expression.update` + construct, the subquery represented + by the :func:`_expression.select` should be *correlated* to the + parent table, that is, providing criterion which links the table inside + the subquery to the outer table being updated:: + + users.update().values( + name=select(addresses.c.email_address).\ + where(addresses.c.user_id==users.c.id).\ + scalar_subquery() + ) + + .. seealso:: + + :ref:`inserts_and_updates` - SQL Expression + Language Tutorial + + + """ + self._preserve_parameter_order = preserve_parameter_order + super(Update, self).__init__(table, values, prefixes) + self._bind = bind + if returning: + self._returning = returning + if whereclause is not None: + self._where_criteria += ( + coercions.expect(roles.WhereHavingRole, whereclause), + ) + self._inline = inline + if dialect_kw: + self._validate_dialect_kwargs_deprecated(dialect_kw) + self._return_defaults = return_defaults + + @_generative + def ordered_values(self, *args): + """Specify the VALUES clause of this UPDATE statement with an explicit + parameter ordering that will be maintained in the SET clause of the + resulting UPDATE statement. + + E.g.:: + + stmt = table.update().ordered_values( + ("name", "ed"), ("ident": "foo") + ) + + .. seealso:: + + :ref:`tutorial_parameter_ordered_updates` - full example of the + :meth:`_expression.Update.ordered_values` method. + + .. versionchanged:: 1.4 The :meth:`_expression.Update.ordered_values` + method + supersedes the + :paramref:`_expression.update.preserve_parameter_order` + parameter, which will be removed in SQLAlchemy 2.0. + + """ + if self._values: + raise exc.ArgumentError( + "This statement already has values present" + ) + elif self._ordered_values: + raise exc.ArgumentError( + "This statement already has ordered values present" + ) + + kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs + self._ordered_values = kv_generator(self, args) + + @_generative + def inline(self): + """Make this :class:`_expression.Update` construct "inline" . + + When set, SQL defaults present on :class:`_schema.Column` + objects via the + ``default`` keyword will be compiled 'inline' into the statement and + not pre-executed. This means that their values will not be available + in the dictionary returned from + :meth:`_engine.CursorResult.last_updated_params`. + + .. versionchanged:: 1.4 the :paramref:`_expression.update.inline` + parameter + is now superseded by the :meth:`_expression.Update.inline` method. + + """ + self._inline = True + + +class Delete(DMLWhereBase, UpdateBase): + """Represent a DELETE construct. + + The :class:`_expression.Delete` object is created using the + :func:`_expression.delete()` function. + + """ + + __visit_name__ = "delete" + + is_delete = True + + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_where_criteria", InternalTraversal.dp_clauseelement_list), + ("_returning", InternalTraversal.dp_clauseelement_list), + ("_hints", InternalTraversal.dp_table_hint_list), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals + ) + + @ValuesBase._constructor_20_deprecations( + "delete", + "Delete", + ["whereclause", "values", "bind", "prefixes", "returning"], + ) + def __init__( + self, + table, + whereclause=None, + bind=None, + returning=None, + prefixes=None, + **dialect_kw + ): + r"""Construct :class:`_expression.Delete` object. + + E.g.:: + + from sqlalchemy import delete + + stmt = ( + delete(user_table). + where(user_table.c.id == 5) + ) + + Similar functionality is available via the + :meth:`_expression.TableClause.delete` method on + :class:`_schema.Table`. + + .. seealso:: + + :ref:`inserts_and_updates` - in the + :ref:`1.x tutorial ` + + :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial` + + + :param table: The table to delete rows from. + + :param whereclause: Optional SQL expression describing the ``WHERE`` + condition of the ``DELETE`` statement; is equivalent to using the + more modern :meth:`~Delete.where()` method to specify the ``WHERE`` + clause. + + .. seealso:: + + :ref:`deletes` - SQL Expression Tutorial + + """ + self._bind = bind + self.table = coercions.expect( + roles.DMLTableRole, table, apply_propagate_attrs=self + ) + if returning: + self._returning = returning + + if prefixes: + self._setup_prefixes(prefixes) + + if whereclause is not None: + self._where_criteria += ( + coercions.expect(roles.WhereHavingRole, whereclause), + ) + + if dialect_kw: + self._validate_dialect_kwargs_deprecated(dialect_kw) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py new file mode 100644 index 0000000..268c0d6 --- /dev/null +++ b/lib/sqlalchemy/sql/elements.py @@ -0,0 +1,5415 @@ +# sql/elements.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Core SQL expression elements, including :class:`_expression.ClauseElement`, +:class:`_expression.ColumnElement`, and derived classes. + +""" + +from __future__ import unicode_literals + +import itertools +import operator +import re + +from . import coercions +from . import operators +from . import roles +from . import traversals +from . import type_api +from .annotation import Annotated +from .annotation import SupportsWrappingAnnotations +from .base import _clone +from .base import _generative +from .base import Executable +from .base import HasMemoized +from .base import Immutable +from .base import NO_ARG +from .base import PARSE_AUTOCOMMIT +from .base import SingletonConstant +from .coercions import _document_text_coercion +from .traversals import HasCopyInternals +from .traversals import MemoizedHasCacheKey +from .traversals import NO_CACHE +from .visitors import cloned_traverse +from .visitors import InternalTraversal +from .visitors import traverse +from .visitors import Traversible +from .. import exc +from .. import inspection +from .. import util + + +def collate(expression, collation): + """Return the clause ``expression COLLATE collation``. + + e.g.:: + + collate(mycolumn, 'utf8_bin') + + produces:: + + mycolumn COLLATE utf8_bin + + The collation expression is also quoted if it is a case sensitive + identifier, e.g. contains uppercase characters. + + .. versionchanged:: 1.2 quoting is automatically applied to COLLATE + expressions if they are case sensitive. + + """ + + expr = coercions.expect(roles.ExpressionElementRole, expression) + return BinaryExpression( + expr, CollationClause(collation), operators.collate, type_=expr.type + ) + + +def between(expr, lower_bound, upper_bound, symmetric=False): + """Produce a ``BETWEEN`` predicate clause. + + E.g.:: + + from sqlalchemy import between + stmt = select(users_table).where(between(users_table.c.id, 5, 7)) + + Would produce SQL resembling:: + + SELECT id, name FROM user WHERE id BETWEEN :id_1 AND :id_2 + + The :func:`.between` function is a standalone version of the + :meth:`_expression.ColumnElement.between` method available on all + SQL expressions, as in:: + + stmt = select(users_table).where(users_table.c.id.between(5, 7)) + + All arguments passed to :func:`.between`, including the left side + column expression, are coerced from Python scalar values if a + the value is not a :class:`_expression.ColumnElement` subclass. + For example, + three fixed values can be compared as in:: + + print(between(5, 3, 7)) + + Which would produce:: + + :param_1 BETWEEN :param_2 AND :param_3 + + :param expr: a column expression, typically a + :class:`_expression.ColumnElement` + instance or alternatively a Python scalar expression to be coerced + into a column expression, serving as the left side of the ``BETWEEN`` + expression. + + :param lower_bound: a column or Python scalar expression serving as the + lower bound of the right side of the ``BETWEEN`` expression. + + :param upper_bound: a column or Python scalar expression serving as the + upper bound of the right side of the ``BETWEEN`` expression. + + :param symmetric: if True, will render " BETWEEN SYMMETRIC ". Note + that not all databases support this syntax. + + .. versionadded:: 0.9.5 + + .. seealso:: + + :meth:`_expression.ColumnElement.between` + + """ + expr = coercions.expect(roles.ExpressionElementRole, expr) + return expr.between(lower_bound, upper_bound, symmetric=symmetric) + + +def literal(value, type_=None): + r"""Return a literal clause, bound to a bind parameter. + + Literal clauses are created automatically when non- + :class:`_expression.ClauseElement` objects (such as strings, ints, dates, + etc.) are + used in a comparison operation with a :class:`_expression.ColumnElement` + subclass, + such as a :class:`~sqlalchemy.schema.Column` object. Use this function + to force the generation of a literal clause, which will be created as a + :class:`BindParameter` with a bound value. + + :param value: the value to be bound. Can be any Python object supported by + the underlying DB-API, or is translatable via the given type argument. + + :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine` which + will provide bind-parameter translation for this literal. + + """ + return coercions.expect(roles.LiteralValueRole, value, type_=type_) + + +def outparam(key, type_=None): + r"""Create an 'OUT' parameter for usage in functions (stored procedures), + for databases which support them. + + The ``outparam`` can be used like a regular function parameter. + The "output" value will be available from the + :class:`~sqlalchemy.engine.CursorResult` object via its ``out_parameters`` + attribute, which returns a dictionary containing the values. + + """ + return BindParameter(key, None, type_=type_, unique=False, isoutparam=True) + + +def not_(clause): + """Return a negation of the given clause, i.e. ``NOT(clause)``. + + The ``~`` operator is also overloaded on all + :class:`_expression.ColumnElement` subclasses to produce the + same result. + + """ + return operators.inv(coercions.expect(roles.ExpressionElementRole, clause)) + + +@inspection._self_inspects +class ClauseElement( + roles.SQLRole, + SupportsWrappingAnnotations, + MemoizedHasCacheKey, + HasCopyInternals, + Traversible, +): + """Base class for elements of a programmatically constructed SQL + expression. + + """ + + __visit_name__ = "clause" + + _propagate_attrs = util.immutabledict() + """like annotations, however these propagate outwards liberally + as SQL constructs are built, and are set up at construction time. + + """ + + supports_execution = False + + stringify_dialect = "default" + + _from_objects = [] + bind = None + description = None + _is_clone_of = None + + is_clause_element = True + is_selectable = False + + _is_textual = False + _is_from_clause = False + _is_returns_rows = False + _is_text_clause = False + _is_from_container = False + _is_select_container = False + _is_select_statement = False + _is_bind_parameter = False + _is_clause_list = False + _is_lambda_element = False + _is_singleton_constant = False + _is_immutable = False + _is_star = False + + _order_by_label_element = None + + _cache_key_traversal = None + + def _set_propagate_attrs(self, values): + # usually, self._propagate_attrs is empty here. one case where it's + # not is a subquery against ORM select, that is then pulled as a + # property of an aliased class. should all be good + + # assert not self._propagate_attrs + + self._propagate_attrs = util.immutabledict(values) + return self + + def _clone(self, **kw): + """Create a shallow copy of this ClauseElement. + + This method may be used by a generative API. Its also used as + part of the "deep" copy afforded by a traversal that combines + the _copy_internals() method. + + """ + skip = self._memoized_keys + c = self.__class__.__new__(self.__class__) + + if skip: + # ensure this iteration remains atomic + c.__dict__ = { + k: v for k, v in self.__dict__.copy().items() if k not in skip + } + else: + c.__dict__ = self.__dict__.copy() + + # this is a marker that helps to "equate" clauses to each other + # when a Select returns its list of FROM clauses. the cloning + # process leaves around a lot of remnants of the previous clause + # typically in the form of column expressions still attached to the + # old table. + cc = self._is_clone_of + c._is_clone_of = cc if cc is not None else self + return c + + def _negate_in_binary(self, negated_op, original_op): + """a hook to allow the right side of a binary expression to respond + to a negation of the binary expression. + + Used for the special case of expanding bind parameter with IN. + + """ + return self + + def _with_binary_element_type(self, type_): + """in the context of binary expression, convert the type of this + object to the one given. + + applies only to :class:`_expression.ColumnElement` classes. + + """ + return self + + @property + def _constructor(self): + """return the 'constructor' for this ClauseElement. + + This is for the purposes for creating a new object of + this type. Usually, its just the element's __class__. + However, the "Annotated" version of the object overrides + to return the class of its proxied element. + + """ + return self.__class__ + + @HasMemoized.memoized_attribute + def _cloned_set(self): + """Return the set consisting all cloned ancestors of this + ClauseElement. + + Includes this ClauseElement. This accessor tends to be used for + FromClause objects to identify 'equivalent' FROM clauses, regardless + of transformative operations. + + """ + s = util.column_set() + f = self + + # note this creates a cycle, asserted in test_memusage. however, + # turning this into a plain @property adds tends of thousands of method + # calls to Core / ORM performance tests, so the small overhead + # introduced by the relatively small amount of short term cycles + # produced here is preferable + while f is not None: + s.add(f) + f = f._is_clone_of + return s + + @property + def entity_namespace(self): + raise AttributeError( + "This SQL expression has no entity namespace " + "with which to filter from." + ) + + def __getstate__(self): + d = self.__dict__.copy() + d.pop("_is_clone_of", None) + d.pop("_generate_cache_key", None) + return d + + def _execute_on_connection( + self, connection, multiparams, params, execution_options, _force=False + ): + if _force or self.supports_execution: + return connection._execute_clauseelement( + self, multiparams, params, execution_options + ) + else: + raise exc.ObjectNotExecutableError(self) + + def unique_params(self, *optionaldict, **kwargs): + """Return a copy with :func:`_expression.bindparam` elements + replaced. + + Same functionality as :meth:`_expression.ClauseElement.params`, + except adds `unique=True` + to affected bind parameters so that multiple statements can be + used. + + """ + return self._replace_params(True, optionaldict, kwargs) + + def params(self, *optionaldict, **kwargs): + """Return a copy with :func:`_expression.bindparam` elements + replaced. + + Returns a copy of this ClauseElement with + :func:`_expression.bindparam` + elements replaced with values taken from the given dictionary:: + + >>> clause = column('x') + bindparam('foo') + >>> print(clause.compile().params) + {'foo':None} + >>> print(clause.params({'foo':7}).compile().params) + {'foo':7} + + """ + return self._replace_params(False, optionaldict, kwargs) + + def _replace_params(self, unique, optionaldict, kwargs): + + if len(optionaldict) == 1: + kwargs.update(optionaldict[0]) + elif len(optionaldict) > 1: + raise exc.ArgumentError( + "params() takes zero or one positional dictionary argument" + ) + + def visit_bindparam(bind): + if bind.key in kwargs: + bind.value = kwargs[bind.key] + bind.required = False + if unique: + bind._convert_to_unique() + + return cloned_traverse( + self, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, + ) + + def compare(self, other, **kw): + r"""Compare this :class:`_expression.ClauseElement` to + the given :class:`_expression.ClauseElement`. + + Subclasses should override the default behavior, which is a + straight identity comparison. + + \**kw are arguments consumed by subclass ``compare()`` methods and + may be used to modify the criteria for comparison + (see :class:`_expression.ColumnElement`). + + """ + return traversals.compare(self, other, **kw) + + def self_group(self, against=None): + """Apply a 'grouping' to this :class:`_expression.ClauseElement`. + + This method is overridden by subclasses to return a "grouping" + construct, i.e. parenthesis. In particular it's used by "binary" + expressions to provide a grouping around themselves when placed into a + larger expression, as well as by :func:`_expression.select` + constructs when placed into the FROM clause of another + :func:`_expression.select`. (Note that subqueries should be + normally created using the :meth:`_expression.Select.alias` method, + as many + platforms require nested SELECT statements to be named). + + As expressions are composed together, the application of + :meth:`self_group` is automatic - end-user code should never + need to use this method directly. Note that SQLAlchemy's + clause constructs take operator precedence into account - + so parenthesis might not be needed, for example, in + an expression like ``x OR (y AND z)`` - AND takes precedence + over OR. + + The base :meth:`self_group` method of + :class:`_expression.ClauseElement` + just returns self. + """ + return self + + def _ungroup(self): + """Return this :class:`_expression.ClauseElement` + without any groupings. + """ + + return self + + @util.preload_module("sqlalchemy.engine.default") + @util.preload_module("sqlalchemy.engine.url") + def compile(self, bind=None, dialect=None, **kw): + """Compile this SQL expression. + + The return value is a :class:`~.Compiled` object. + Calling ``str()`` or ``unicode()`` on the returned value will yield a + string representation of the result. The + :class:`~.Compiled` object also can return a + dictionary of bind parameter names and values + using the ``params`` accessor. + + :param bind: An ``Engine`` or ``Connection`` from which a + ``Compiled`` will be acquired. This argument takes precedence over + this :class:`_expression.ClauseElement`'s bound engine, if any. + + :param column_keys: Used for INSERT and UPDATE statements, a list of + column names which should be present in the VALUES clause of the + compiled statement. If ``None``, all columns from the target table + object are rendered. + + :param dialect: A ``Dialect`` instance from which a ``Compiled`` + will be acquired. This argument takes precedence over the `bind` + argument as well as this :class:`_expression.ClauseElement` + 's bound engine, + if any. + + :param compile_kwargs: optional dictionary of additional parameters + that will be passed through to the compiler within all "visit" + methods. This allows any custom flag to be passed through to + a custom compilation construct, for example. It is also used + for the case of passing the ``literal_binds`` flag through:: + + from sqlalchemy.sql import table, column, select + + t = table('t', column('x')) + + s = select(t).where(t.c.x == 5) + + print(s.compile(compile_kwargs={"literal_binds": True})) + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`faq_sql_expression_string` + + """ + + if not dialect: + if bind: + dialect = bind.dialect + elif self.bind: + dialect = self.bind.dialect + else: + if self.stringify_dialect == "default": + default = util.preloaded.engine_default + dialect = default.StrCompileDialect() + else: + url = util.preloaded.engine_url + dialect = url.URL.create( + self.stringify_dialect + ).get_dialect()() + + return self._compiler(dialect, **kw) + + def _compile_w_cache( + self, + dialect, + compiled_cache=None, + column_keys=None, + for_executemany=False, + schema_translate_map=None, + **kw + ): + if compiled_cache is not None and dialect._supports_statement_cache: + elem_cache_key = self._generate_cache_key() + else: + elem_cache_key = None + + if elem_cache_key: + cache_key, extracted_params = elem_cache_key + key = ( + dialect, + cache_key, + tuple(column_keys), + bool(schema_translate_map), + for_executemany, + ) + compiled_sql = compiled_cache.get(key) + + if compiled_sql is None: + cache_hit = dialect.CACHE_MISS + compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + for_executemany=for_executemany, + schema_translate_map=schema_translate_map, + **kw + ) + compiled_cache[key] = compiled_sql + else: + cache_hit = dialect.CACHE_HIT + else: + extracted_params = None + compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + for_executemany=for_executemany, + schema_translate_map=schema_translate_map, + **kw + ) + + if not dialect._supports_statement_cache: + cache_hit = dialect.NO_DIALECT_SUPPORT + elif compiled_cache is None: + cache_hit = dialect.CACHING_DISABLED + else: + cache_hit = dialect.NO_CACHE_KEY + + return compiled_sql, extracted_params, cache_hit + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + + return dialect.statement_compiler(dialect, self, **kw) + + def __str__(self): + if util.py3k: + return str(self.compile()) + else: + return unicode(self.compile()).encode( # noqa + "ascii", "backslashreplace" + ) # noqa + + def __invert__(self): + # undocumented element currently used by the ORM for + # relationship.contains() + if hasattr(self, "negation_clause"): + return self.negation_clause + else: + return self._negate() + + def _negate(self): + return UnaryExpression( + self.self_group(against=operators.inv), operator=operators.inv + ) + + def __bool__(self): + raise TypeError("Boolean value of this clause is not defined") + + __nonzero__ = __bool__ + + def __repr__(self): + friendly = self.description + if friendly is None: + return object.__repr__(self) + else: + return "<%s.%s at 0x%x; %s>" % ( + self.__module__, + self.__class__.__name__, + id(self), + friendly, + ) + + +class ColumnElement( + roles.ColumnArgumentOrKeyRole, + roles.StatementOptionRole, + roles.WhereHavingRole, + roles.BinaryElementRole, + roles.OrderByRole, + roles.ColumnsClauseRole, + roles.LimitOffsetRole, + roles.DMLColumnRole, + roles.DDLConstraintColumnRole, + roles.DDLExpressionRole, + operators.ColumnOperators, + ClauseElement, +): + """Represent a column-oriented SQL expression suitable for usage in the + "columns" clause, WHERE clause etc. of a statement. + + While the most familiar kind of :class:`_expression.ColumnElement` is the + :class:`_schema.Column` object, :class:`_expression.ColumnElement` + serves as the basis + for any unit that may be present in a SQL expression, including + the expressions themselves, SQL functions, bound parameters, + literal expressions, keywords such as ``NULL``, etc. + :class:`_expression.ColumnElement` + is the ultimate base class for all such elements. + + A wide variety of SQLAlchemy Core functions work at the SQL expression + level, and are intended to accept instances of + :class:`_expression.ColumnElement` as + arguments. These functions will typically document that they accept a + "SQL expression" as an argument. What this means in terms of SQLAlchemy + usually refers to an input which is either already in the form of a + :class:`_expression.ColumnElement` object, + or a value which can be **coerced** into + one. The coercion rules followed by most, but not all, SQLAlchemy Core + functions with regards to SQL expressions are as follows: + + * a literal Python value, such as a string, integer or floating + point value, boolean, datetime, ``Decimal`` object, or virtually + any other Python object, will be coerced into a "literal bound + value". This generally means that a :func:`.bindparam` will be + produced featuring the given value embedded into the construct; the + resulting :class:`.BindParameter` object is an instance of + :class:`_expression.ColumnElement`. + The Python value will ultimately be sent + to the DBAPI at execution time as a parameterized argument to the + ``execute()`` or ``executemany()`` methods, after SQLAlchemy + type-specific converters (e.g. those provided by any associated + :class:`.TypeEngine` objects) are applied to the value. + + * any special object value, typically ORM-level constructs, which + feature an accessor called ``__clause_element__()``. The Core + expression system looks for this method when an object of otherwise + unknown type is passed to a function that is looking to coerce the + argument into a :class:`_expression.ColumnElement` and sometimes a + :class:`_expression.SelectBase` expression. + It is used within the ORM to + convert from ORM-specific objects like mapped classes and + mapped attributes into Core expression objects. + + * The Python ``None`` value is typically interpreted as ``NULL``, + which in SQLAlchemy Core produces an instance of :func:`.null`. + + A :class:`_expression.ColumnElement` provides the ability to generate new + :class:`_expression.ColumnElement` + objects using Python expressions. This means that Python operators + such as ``==``, ``!=`` and ``<`` are overloaded to mimic SQL operations, + and allow the instantiation of further :class:`_expression.ColumnElement` + instances + which are composed from other, more fundamental + :class:`_expression.ColumnElement` + objects. For example, two :class:`.ColumnClause` objects can be added + together with the addition operator ``+`` to produce + a :class:`.BinaryExpression`. + Both :class:`.ColumnClause` and :class:`.BinaryExpression` are subclasses + of :class:`_expression.ColumnElement`:: + + >>> from sqlalchemy.sql import column + >>> column('a') + column('b') + + >>> print(column('a') + column('b')) + a + b + + .. seealso:: + + :class:`_schema.Column` + + :func:`_expression.column` + + """ + + __visit_name__ = "column_element" + primary_key = False + foreign_keys = [] + _proxies = () + + _tq_label = None + """The named label that can be used to target + this column in a result set in a "table qualified" context. + + This label is almost always the label used when + rendering AS