Diffstat (limited to 'lib/sqlalchemy/ext')
-rw-r--r--  lib/sqlalchemy/ext/__init__.py                  11
-rw-r--r--  lib/sqlalchemy/ext/associationproxy.py        1627
-rw-r--r--  lib/sqlalchemy/ext/asyncio/__init__.py          22
-rw-r--r--  lib/sqlalchemy/ext/asyncio/base.py              89
-rw-r--r--  lib/sqlalchemy/ext/asyncio/engine.py           828
-rw-r--r--  lib/sqlalchemy/ext/asyncio/events.py            44
-rw-r--r--  lib/sqlalchemy/ext/asyncio/exc.py               21
-rw-r--r--  lib/sqlalchemy/ext/asyncio/result.py           671
-rw-r--r--  lib/sqlalchemy/ext/asyncio/scoping.py          107
-rw-r--r--  lib/sqlalchemy/ext/asyncio/session.py          759
-rw-r--r--  lib/sqlalchemy/ext/automap.py                 1234
-rw-r--r--  lib/sqlalchemy/ext/baked.py                    648
-rw-r--r--  lib/sqlalchemy/ext/compiler.py                 613
-rw-r--r--  lib/sqlalchemy/ext/declarative/__init__.py      64
-rw-r--r--  lib/sqlalchemy/ext/declarative/extensions.py   463
-rw-r--r--  lib/sqlalchemy/ext/horizontal_shard.py         256
-rw-r--r--  lib/sqlalchemy/ext/hybrid.py                  1206
-rw-r--r--  lib/sqlalchemy/ext/indexable.py                352
-rw-r--r--  lib/sqlalchemy/ext/instrumentation.py          416
-rw-r--r--  lib/sqlalchemy/ext/mutable.py                  958
-rw-r--r--  lib/sqlalchemy/ext/mypy/__init__.py              0
-rw-r--r--  lib/sqlalchemy/ext/mypy/apply.py               299
-rw-r--r--  lib/sqlalchemy/ext/mypy/decl_class.py          516
-rw-r--r--  lib/sqlalchemy/ext/mypy/infer.py               556
-rw-r--r--  lib/sqlalchemy/ext/mypy/names.py               253
-rw-r--r--  lib/sqlalchemy/ext/mypy/plugin.py              284
-rw-r--r--  lib/sqlalchemy/ext/mypy/util.py                305
-rw-r--r--  lib/sqlalchemy/ext/orderinglist.py             388
-rw-r--r--  lib/sqlalchemy/ext/serializer.py               177
29 files changed, 13167 insertions(+), 0 deletions(-)
diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py
new file mode 100644
index 0000000..62bbbf3
--- /dev/null
+++ b/lib/sqlalchemy/ext/__init__.py
@@ -0,0 +1,11 @@
+# ext/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .. import util as _sa_util
+
+
+_sa_util.preloaded.import_prefix("sqlalchemy.ext")
diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py
new file mode 100644
index 0000000..fbf377a
--- /dev/null
+++ b/lib/sqlalchemy/ext/associationproxy.py
@@ -0,0 +1,1627 @@
+# ext/associationproxy.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Contain the ``AssociationProxy`` class.
+
+The ``AssociationProxy`` is a Python property object which provides
+transparent proxied access to the endpoint of an association object.
+
+See the example ``examples/association/proxied_association.py``.
+
+"""
+import operator
+
+from .. import exc
+from .. import inspect
+from .. import orm
+from .. import util
+from ..orm import collections
+from ..orm import interfaces
+from ..sql import or_
+from ..sql.operators import ColumnOperators
+
+
+def association_proxy(target_collection, attr, **kw):
+ r"""Return a Python property implementing a view of a target
+ attribute which references an attribute on members of the
+ target.
+
+ The returned value is an instance of :class:`.AssociationProxy`.
+
+ Implements a Python property representing a relationship as a collection
+ of simpler values, or a scalar value. The proxied property will mimic
+ the collection type of the target (list, dict or set), or, in the case of
+ a one to one relationship, a simple scalar value.
+
+ :param target_collection: Name of the attribute we'll proxy to.
+ This attribute is typically mapped by
+ :func:`~sqlalchemy.orm.relationship` to link to a target collection, but
+    can also be a many-to-one or otherwise uselist=False relationship.
+
+ :param attr: Attribute on the associated instance or instances we'll
+ proxy for.
+
+ For example, given a target collection of [obj1, obj2], a list created
+ by this proxy property would look like [getattr(obj1, *attr*),
+ getattr(obj2, *attr*)]
+
+ If the relationship is one-to-one or otherwise uselist=False, then
+ simply: getattr(obj, *attr*)
+
+ :param creator: optional.
+
+ When new items are added to this proxied collection, new instances of
+ the class collected by the target collection will be created. For list
+ and set collections, the target class constructor will be called with
+ the 'value' for the new instance. For dict types, two arguments are
+ passed: key and value.
+
+ If you want to construct instances differently, supply a *creator*
+ function that takes arguments as above and returns instances.
+
+ For scalar relationships, creator() will be called if the target is None.
+ If the target is present, set operations are proxied to setattr() on the
+ associated object.
+
+ If you have an associated object with multiple attributes, you may set
+ up multiple association proxies mapping to different attributes. See
+ the unit tests for examples, and for examples of how creator() functions
+ can be used to construct the scalar relationship on-demand in this
+ situation.
+
+ :param \*\*kw: Passes along any other keyword arguments to
+ :class:`.AssociationProxy`.
+
+ """
+ return AssociationProxy(target_collection, attr, **kw)
+
+
+ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY")
+"""Symbol indicating an :class:`.InspectionAttr` that's
+ of type :class:`.AssociationProxy`.
+
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+"""
+
+
+class AssociationProxy(interfaces.InspectionAttrInfo):
+ """A descriptor that presents a read/write view of an object attribute."""
+
+ is_attribute = True
+ extension_type = ASSOCIATION_PROXY
+
+ def __init__(
+ self,
+ target_collection,
+ attr,
+ creator=None,
+ getset_factory=None,
+ proxy_factory=None,
+ proxy_bulk_set=None,
+ info=None,
+ cascade_scalar_deletes=False,
+ ):
+ """Construct a new :class:`.AssociationProxy`.
+
+ The :func:`.association_proxy` function is provided as the usual
+ entrypoint here, though :class:`.AssociationProxy` can be instantiated
+ and/or subclassed directly.
+
+ :param target_collection: Name of the collection we'll proxy to,
+ usually created with :func:`_orm.relationship`.
+
+ :param attr: Attribute on the collected instances we'll proxy
+ for. For example, given a target collection of [obj1, obj2], a
+ list created by this proxy property would look like
+ [getattr(obj1, attr), getattr(obj2, attr)]
+
+ :param creator: Optional. When new items are added to this proxied
+ collection, new instances of the class collected by the target
+ collection will be created. For list and set collections, the
+ target class constructor will be called with the 'value' for the
+ new instance. For dict types, two arguments are passed:
+ key and value.
+
+ If you want to construct instances differently, supply a 'creator'
+ function that takes arguments as above and returns instances.
+
+ :param cascade_scalar_deletes: when True, indicates that setting
+ the proxied value to ``None``, or deleting it via ``del``, should
+ also remove the source object. Only applies to scalar attributes.
+ Normally, removing the proxied target will not remove the proxy
+ source, as this object may have other state that is still to be
+ kept.
+
+ .. versionadded:: 1.3
+
+ .. seealso::
+
+ :ref:`cascade_scalar_deletes` - complete usage example
+
+ :param getset_factory: Optional. Proxied attribute access is
+ automatically handled by routines that get and set values based on
+ the `attr` argument for this proxy.
+
+ If you would like to customize this behavior, you may supply a
+ `getset_factory` callable that produces a tuple of `getter` and
+ `setter` functions. The factory is called with two arguments, the
+ abstract type of the underlying collection and this proxy instance.
+
+ :param proxy_factory: Optional. The type of collection to emulate is
+ determined by sniffing the target collection. If your collection
+ type can't be determined by duck typing or you'd like to use a
+ different collection implementation, you may supply a factory
+ function to produce those collections. Only applicable to
+ non-scalar relationships.
+
+ :param proxy_bulk_set: Optional, use with proxy_factory. See
+ the _set() method for details.
+
+ :param info: optional, will be assigned to
+ :attr:`.AssociationProxy.info` if present.
+
+ .. versionadded:: 1.0.9
+
+ """
+ self.target_collection = target_collection
+ self.value_attr = attr
+ self.creator = creator
+ self.getset_factory = getset_factory
+ self.proxy_factory = proxy_factory
+ self.proxy_bulk_set = proxy_bulk_set
+ self.cascade_scalar_deletes = cascade_scalar_deletes
+
+ self.key = "_%s_%s_%s" % (
+ type(self).__name__,
+ target_collection,
+ id(self),
+ )
+ if info:
+ self.info = info
+
+ def __get__(self, obj, class_):
+ if class_ is None:
+ return self
+ inst = self._as_instance(class_, obj)
+ if inst:
+ return inst.get(obj)
+
+ # obj has to be None here
+ # assert obj is None
+
+ return self
+
+ def __set__(self, obj, values):
+ class_ = type(obj)
+ return self._as_instance(class_, obj).set(obj, values)
+
+ def __delete__(self, obj):
+ class_ = type(obj)
+ return self._as_instance(class_, obj).delete(obj)
+
+ def for_class(self, class_, obj=None):
+ r"""Return the internal state local to a specific mapped class.
+
+ E.g., given a class ``User``::
+
+ class User(Base):
+ # ...
+
+ keywords = association_proxy('kws', 'keyword')
+
+ If we access this :class:`.AssociationProxy` from
+ :attr:`_orm.Mapper.all_orm_descriptors`, and we want to view the
+ target class for this proxy as mapped by ``User``::
+
+ inspect(User).all_orm_descriptors["keywords"].for_class(User).target_class
+
+ This returns an instance of :class:`.AssociationProxyInstance` that
+ is specific to the ``User`` class. The :class:`.AssociationProxy`
+ object remains agnostic of its parent class.
+
+ :param class\_: the class that we are returning state for.
+
+ :param obj: optional, an instance of the class that is required
+ if the attribute refers to a polymorphic target, e.g. where we have
+ to look at the type of the actual destination object to get the
+ complete path.
+
+ .. versionadded:: 1.3 - :class:`.AssociationProxy` no longer stores
+ any state specific to a particular parent class; the state is now
+ stored in per-class :class:`.AssociationProxyInstance` objects.
+
+
+ """
+ return self._as_instance(class_, obj)
+
+ def _as_instance(self, class_, obj):
+ try:
+ inst = class_.__dict__[self.key + "_inst"]
+ except KeyError:
+ inst = None
+
+ # avoid exception context
+ if inst is None:
+ owner = self._calc_owner(class_)
+ if owner is not None:
+ inst = AssociationProxyInstance.for_proxy(self, owner, obj)
+ setattr(class_, self.key + "_inst", inst)
+ else:
+ inst = None
+
+ if inst is not None and not inst._is_canonical:
+ # the AssociationProxyInstance can't be generalized
+ # since the proxied attribute is not on the targeted
+ # class, only on subclasses of it, which might be
+ # different. only return for the specific
+ # object's current value
+ return inst._non_canonical_get_for_object(obj)
+ else:
+ return inst
+
+ def _calc_owner(self, target_cls):
+ # we might be getting invoked for a subclass
+ # that is not mapped yet, in some declarative situations.
+ # save until we are mapped
+ try:
+ insp = inspect(target_cls)
+ except exc.NoInspectionAvailable:
+ # can't find a mapper, don't set owner. if we are a not-yet-mapped
+ # subclass, we can also scan through __mro__ to find a mapped
+ # class, but instead just wait for us to be called again against a
+ # mapped class normally.
+ return None
+ else:
+ return insp.mapper.class_manager.class_
+
+ def _default_getset(self, collection_class):
+ attr = self.value_attr
+ _getter = operator.attrgetter(attr)
+
+ def getter(target):
+ return _getter(target) if target is not None else None
+
+ if collection_class is dict:
+
+ def setter(o, k, v):
+ setattr(o, attr, v)
+
+ else:
+
+ def setter(o, v):
+ setattr(o, attr, v)
+
+ return getter, setter
+
+ def __repr__(self):
+ return "AssociationProxy(%r, %r)" % (
+ self.target_collection,
+ self.value_attr,
+ )
+
+
+class AssociationProxyInstance(object):
+ """A per-class object that serves class- and object-specific results.
+
+ This is used by :class:`.AssociationProxy` when it is invoked
+ in terms of a specific class or instance of a class, i.e. when it is
+ used as a regular Python descriptor.
+
+ When referring to the :class:`.AssociationProxy` as a normal Python
+ descriptor, the :class:`.AssociationProxyInstance` is the object that
+ actually serves the information. Under normal circumstances, its presence
+ is transparent::
+
+ >>> User.keywords.scalar
+ False
+
+ In the special case that the :class:`.AssociationProxy` object is being
+ accessed directly, in order to get an explicit handle to the
+ :class:`.AssociationProxyInstance`, use the
+ :meth:`.AssociationProxy.for_class` method::
+
+ proxy_state = inspect(User).all_orm_descriptors["keywords"].for_class(User)
+
+ # view if proxy object is scalar or not
+ >>> proxy_state.scalar
+ False
+
+ .. versionadded:: 1.3
+
+ """ # noqa
+
+ def __init__(self, parent, owning_class, target_class, value_attr):
+ self.parent = parent
+ self.key = parent.key
+ self.owning_class = owning_class
+ self.target_collection = parent.target_collection
+ self.collection_class = None
+ self.target_class = target_class
+ self.value_attr = value_attr
+
+ target_class = None
+ """The intermediary class handled by this
+ :class:`.AssociationProxyInstance`.
+
+ Intercepted append/set/assignment events will result
+ in the generation of new instances of this class.
+
+ """
+
+ @classmethod
+ def for_proxy(cls, parent, owning_class, parent_instance):
+ target_collection = parent.target_collection
+ value_attr = parent.value_attr
+ prop = orm.class_mapper(owning_class).get_property(target_collection)
+
+ # this was never asserted before but this should be made clear.
+ if not isinstance(prop, orm.RelationshipProperty):
+ util.raise_(
+ NotImplementedError(
+ "association proxy to a non-relationship "
+ "intermediary is not supported"
+ ),
+ replace_context=None,
+ )
+
+ target_class = prop.mapper.class_
+
+ try:
+ target_assoc = cls._cls_unwrap_target_assoc_proxy(
+ target_class, value_attr
+ )
+ except AttributeError:
+ # the proxied attribute doesn't exist on the target class;
+ # return an "ambiguous" instance that will work on a per-object
+ # basis
+ return AmbiguousAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+ except Exception as err:
+ util.raise_(
+ exc.InvalidRequestError(
+ "Association proxy received an unexpected error when "
+ "trying to retreive attribute "
+ '"%s.%s" from '
+ 'class "%s": %s'
+ % (
+ target_class.__name__,
+ parent.value_attr,
+ target_class.__name__,
+ err,
+ )
+ ),
+ from_=err,
+ )
+ else:
+ return cls._construct_for_assoc(
+ target_assoc, parent, owning_class, target_class, value_attr
+ )
+
+ @classmethod
+ def _construct_for_assoc(
+ cls, target_assoc, parent, owning_class, target_class, value_attr
+ ):
+ if target_assoc is not None:
+ return ObjectAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+
+ attr = getattr(target_class, value_attr)
+ if not hasattr(attr, "_is_internal_proxy"):
+ return AmbiguousAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+ is_object = attr._impl_uses_objects
+ if is_object:
+ return ObjectAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+ else:
+ return ColumnAssociationProxyInstance(
+ parent, owning_class, target_class, value_attr
+ )
+
+ def _get_property(self):
+ return orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
+
+ @property
+ def _comparator(self):
+ return self._get_property().comparator
+
+ def __clause_element__(self):
+ raise NotImplementedError(
+ "The association proxy can't be used as a plain column "
+ "expression; it only works inside of a comparison expression"
+ )
+
+ @classmethod
+ def _cls_unwrap_target_assoc_proxy(cls, target_class, value_attr):
+ attr = getattr(target_class, value_attr)
+ if isinstance(attr, (AssociationProxy, AssociationProxyInstance)):
+ return attr
+ return None
+
+ @util.memoized_property
+ def _unwrap_target_assoc_proxy(self):
+ return self._cls_unwrap_target_assoc_proxy(
+ self.target_class, self.value_attr
+ )
+
+ @property
+ def remote_attr(self):
+ """The 'remote' class attribute referenced by this
+ :class:`.AssociationProxyInstance`.
+
+ .. seealso::
+
+ :attr:`.AssociationProxyInstance.attr`
+
+ :attr:`.AssociationProxyInstance.local_attr`
+
+ """
+ return getattr(self.target_class, self.value_attr)
+
+ @property
+ def local_attr(self):
+ """The 'local' class attribute referenced by this
+ :class:`.AssociationProxyInstance`.
+
+ .. seealso::
+
+ :attr:`.AssociationProxyInstance.attr`
+
+ :attr:`.AssociationProxyInstance.remote_attr`
+
+ """
+ return getattr(self.owning_class, self.target_collection)
+
+ @property
+ def attr(self):
+ """Return a tuple of ``(local_attr, remote_attr)``.
+
+ This attribute was originally intended to facilitate using the
+ :meth:`_query.Query.join` method to join across the two relationships
+ at once, however this makes use of a deprecated calling style.
+
+ To use :meth:`_sql.select.join` or :meth:`_orm.Query.join` with
+ an association proxy, the current method is to make use of the
+ :attr:`.AssociationProxyInstance.local_attr` and
+ :attr:`.AssociationProxyInstance.remote_attr` attributes separately::
+
+ stmt = (
+ select(Parent).
+ join(Parent.proxied.local_attr).
+ join(Parent.proxied.remote_attr)
+ )
+
+ A future release may seek to provide a more succinct join pattern
+ for association proxy attributes.
+
+ .. seealso::
+
+ :attr:`.AssociationProxyInstance.local_attr`
+
+ :attr:`.AssociationProxyInstance.remote_attr`
+
+ """
+ return (self.local_attr, self.remote_attr)
+
+ @util.memoized_property
+ def scalar(self):
+ """Return ``True`` if this :class:`.AssociationProxyInstance`
+ proxies a scalar relationship on the local side."""
+
+ scalar = not self._get_property().uselist
+ if scalar:
+ self._initialize_scalar_accessors()
+ return scalar
+
+ @util.memoized_property
+ def _value_is_scalar(self):
+ return (
+ not self._get_property()
+ .mapper.get_property(self.value_attr)
+ .uselist
+ )
+
+ @property
+ def _target_is_object(self):
+ raise NotImplementedError()
+
+ def _initialize_scalar_accessors(self):
+ if self.parent.getset_factory:
+ get, set_ = self.parent.getset_factory(None, self)
+ else:
+ get, set_ = self.parent._default_getset(None)
+ self._scalar_get, self._scalar_set = get, set_
+
+ def _default_getset(self, collection_class):
+ attr = self.value_attr
+ _getter = operator.attrgetter(attr)
+
+ def getter(target):
+ return _getter(target) if target is not None else None
+
+ if collection_class is dict:
+
+ def setter(o, k, v):
+ return setattr(o, attr, v)
+
+ else:
+
+ def setter(o, v):
+ return setattr(o, attr, v)
+
+ return getter, setter
+
+ @property
+ def info(self):
+ return self.parent.info
+
+ def get(self, obj):
+ if obj is None:
+ return self
+
+ if self.scalar:
+ target = getattr(obj, self.target_collection)
+ return self._scalar_get(target)
+ else:
+ try:
+ # If the owning instance is reborn (orm session resurrect,
+ # etc.), refresh the proxy cache.
+ creator_id, self_id, proxy = getattr(obj, self.key)
+ except AttributeError:
+ pass
+ else:
+ if id(obj) == creator_id and id(self) == self_id:
+ assert self.collection_class is not None
+ return proxy
+
+ self.collection_class, proxy = self._new(
+ _lazy_collection(obj, self.target_collection)
+ )
+ setattr(obj, self.key, (id(obj), id(self), proxy))
+ return proxy
+
+ def set(self, obj, values):
+ if self.scalar:
+ creator = (
+ self.parent.creator
+ if self.parent.creator
+ else self.target_class
+ )
+ target = getattr(obj, self.target_collection)
+ if target is None:
+ if values is None:
+ return
+ setattr(obj, self.target_collection, creator(values))
+ else:
+ self._scalar_set(target, values)
+ if values is None and self.parent.cascade_scalar_deletes:
+ setattr(obj, self.target_collection, None)
+ else:
+ proxy = self.get(obj)
+ assert self.collection_class is not None
+ if proxy is not values:
+ proxy._bulk_replace(self, values)
+
+ def delete(self, obj):
+ if self.owning_class is None:
+ self._calc_owner(obj, None)
+
+ if self.scalar:
+ target = getattr(obj, self.target_collection)
+ if target is not None:
+ delattr(target, self.value_attr)
+ delattr(obj, self.target_collection)
+
+ def _new(self, lazy_collection):
+ creator = (
+ self.parent.creator if self.parent.creator else self.target_class
+ )
+ collection_class = util.duck_type_collection(lazy_collection())
+
+ if self.parent.proxy_factory:
+ return (
+ collection_class,
+ self.parent.proxy_factory(
+ lazy_collection, creator, self.value_attr, self
+ ),
+ )
+
+ if self.parent.getset_factory:
+ getter, setter = self.parent.getset_factory(collection_class, self)
+ else:
+ getter, setter = self.parent._default_getset(collection_class)
+
+ if collection_class is list:
+ return (
+ collection_class,
+ _AssociationList(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
+ elif collection_class is dict:
+ return (
+ collection_class,
+ _AssociationDict(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
+ elif collection_class is set:
+ return (
+ collection_class,
+ _AssociationSet(
+ lazy_collection, creator, getter, setter, self
+ ),
+ )
+ else:
+ raise exc.ArgumentError(
+ "could not guess which interface to use for "
+ 'collection_class "%s" backing "%s"; specify a '
+ "proxy_factory and proxy_bulk_set manually"
+                % (collection_class.__name__, self.target_collection)
+ )
+
+ def _set(self, proxy, values):
+ if self.parent.proxy_bulk_set:
+ self.parent.proxy_bulk_set(proxy, values)
+ elif self.collection_class is list:
+ proxy.extend(values)
+ elif self.collection_class is dict:
+ proxy.update(values)
+ elif self.collection_class is set:
+ proxy.update(values)
+ else:
+ raise exc.ArgumentError(
+ "no proxy_bulk_set supplied for custom "
+ "collection_class implementation"
+ )
+
+ def _inflate(self, proxy):
+ creator = (
+            self.parent.creator if self.parent.creator else self.target_class
+ )
+
+ if self.parent.getset_factory:
+ getter, setter = self.parent.getset_factory(
+ self.collection_class, self
+ )
+ else:
+ getter, setter = self.parent._default_getset(self.collection_class)
+
+ proxy.creator = creator
+ proxy.getter = getter
+ proxy.setter = setter
+
+ def _criterion_exists(self, criterion=None, **kwargs):
+ is_has = kwargs.pop("is_has", None)
+
+ target_assoc = self._unwrap_target_assoc_proxy
+ if target_assoc is not None:
+ inner = target_assoc._criterion_exists(
+ criterion=criterion, **kwargs
+ )
+ return self._comparator._criterion_exists(inner)
+
+ if self._target_is_object:
+ prop = getattr(self.target_class, self.value_attr)
+ value_expr = prop._criterion_exists(criterion, **kwargs)
+ else:
+ if kwargs:
+ raise exc.ArgumentError(
+ "Can't apply keyword arguments to column-targeted "
+ "association proxy; use =="
+ )
+ elif is_has and criterion is not None:
+ raise exc.ArgumentError(
+ "Non-empty has() not allowed for "
+ "column-targeted association proxy; use =="
+ )
+
+ value_expr = criterion
+
+ return self._comparator._criterion_exists(value_expr)
+
+ def any(self, criterion=None, **kwargs):
+ """Produce a proxied 'any' expression using EXISTS.
+
+ This expression will be a composed product
+ using the :meth:`.RelationshipProperty.Comparator.any`
+ and/or :meth:`.RelationshipProperty.Comparator.has`
+ operators of the underlying proxied attributes.
+
+ """
+ if self._unwrap_target_assoc_proxy is None and (
+ self.scalar
+ and (not self._target_is_object or self._value_is_scalar)
+ ):
+ raise exc.InvalidRequestError(
+ "'any()' not implemented for scalar " "attributes. Use has()."
+ )
+ return self._criterion_exists(
+ criterion=criterion, is_has=False, **kwargs
+ )
+
+ def has(self, criterion=None, **kwargs):
+ """Produce a proxied 'has' expression using EXISTS.
+
+ This expression will be a composed product
+ using the :meth:`.RelationshipProperty.Comparator.any`
+ and/or :meth:`.RelationshipProperty.Comparator.has`
+ operators of the underlying proxied attributes.
+
+ """
+ if self._unwrap_target_assoc_proxy is None and (
+ not self.scalar
+ or (self._target_is_object and not self._value_is_scalar)
+ ):
+ raise exc.InvalidRequestError(
+ "'has()' not implemented for collections. " "Use any()."
+ )
+ return self._criterion_exists(
+ criterion=criterion, is_has=True, **kwargs
+ )
+
+ def __repr__(self):
+ return "%s(%r)" % (self.__class__.__name__, self.parent)
+
+
+class AmbiguousAssociationProxyInstance(AssociationProxyInstance):
+ """an :class:`.AssociationProxyInstance` where we cannot determine
+ the type of target object.
+ """
+
+ _is_canonical = False
+
+ def _ambiguous(self):
+ raise AttributeError(
+ "Association proxy %s.%s refers to an attribute '%s' that is not "
+ "directly mapped on class %s; therefore this operation cannot "
+ "proceed since we don't know what type of object is referred "
+ "towards"
+ % (
+ self.owning_class.__name__,
+ self.target_collection,
+ self.value_attr,
+ self.target_class,
+ )
+ )
+
+ def get(self, obj):
+ if obj is None:
+ return self
+ else:
+ return super(AmbiguousAssociationProxyInstance, self).get(obj)
+
+ def __eq__(self, obj):
+ self._ambiguous()
+
+ def __ne__(self, obj):
+ self._ambiguous()
+
+ def any(self, criterion=None, **kwargs):
+ self._ambiguous()
+
+ def has(self, criterion=None, **kwargs):
+ self._ambiguous()
+
+ @util.memoized_property
+ def _lookup_cache(self):
+ # mapping of <subclass>->AssociationProxyInstance.
+ # e.g. proxy is A-> A.b -> B -> B.b_attr, but B.b_attr doesn't exist;
+ # only B1(B) and B2(B) have "b_attr", keys in here would be B1, B2
+ return {}
+
+ def _non_canonical_get_for_object(self, parent_instance):
+ if parent_instance is not None:
+ actual_obj = getattr(parent_instance, self.target_collection)
+ if actual_obj is not None:
+ try:
+ insp = inspect(actual_obj)
+ except exc.NoInspectionAvailable:
+ pass
+ else:
+ mapper = insp.mapper
+ instance_class = mapper.class_
+ if instance_class not in self._lookup_cache:
+ self._populate_cache(instance_class, mapper)
+
+ try:
+ return self._lookup_cache[instance_class]
+ except KeyError:
+ pass
+
+ # no object or ambiguous object given, so return "self", which
+ # is a proxy with generally only instance-level functionality
+ return self
+
+ def _populate_cache(self, instance_class, mapper):
+ prop = orm.class_mapper(self.owning_class).get_property(
+ self.target_collection
+ )
+
+ if mapper.isa(prop.mapper):
+ target_class = instance_class
+ try:
+ target_assoc = self._cls_unwrap_target_assoc_proxy(
+ target_class, self.value_attr
+ )
+ except AttributeError:
+ pass
+ else:
+ self._lookup_cache[instance_class] = self._construct_for_assoc(
+ target_assoc,
+ self.parent,
+ self.owning_class,
+ target_class,
+ self.value_attr,
+ )
+
+
+class ObjectAssociationProxyInstance(AssociationProxyInstance):
+ """an :class:`.AssociationProxyInstance` that has an object as a target."""
+
+ _target_is_object = True
+ _is_canonical = True
+
+ def contains(self, obj):
+ """Produce a proxied 'contains' expression using EXISTS.
+
+ This expression will be a composed product
+ using the :meth:`.RelationshipProperty.Comparator.any`,
+ :meth:`.RelationshipProperty.Comparator.has`,
+ and/or :meth:`.RelationshipProperty.Comparator.contains`
+ operators of the underlying proxied attributes.
+ """
+
+ target_assoc = self._unwrap_target_assoc_proxy
+ if target_assoc is not None:
+ return self._comparator._criterion_exists(
+ target_assoc.contains(obj)
+ if not target_assoc.scalar
+ else target_assoc == obj
+ )
+ elif (
+ self._target_is_object
+ and self.scalar
+ and not self._value_is_scalar
+ ):
+ return self._comparator.has(
+ getattr(self.target_class, self.value_attr).contains(obj)
+ )
+ elif self._target_is_object and self.scalar and self._value_is_scalar:
+ raise exc.InvalidRequestError(
+ "contains() doesn't apply to a scalar object endpoint; use =="
+ )
+ else:
+
+ return self._comparator._criterion_exists(**{self.value_attr: obj})
+
+ def __eq__(self, obj):
+ # note the has() here will fail for collections; eq_()
+ # is only allowed with a scalar.
+ if obj is None:
+ return or_(
+ self._comparator.has(**{self.value_attr: obj}),
+ self._comparator == None,
+ )
+ else:
+ return self._comparator.has(**{self.value_attr: obj})
+
+ def __ne__(self, obj):
+ # note the has() here will fail for collections; eq_()
+ # is only allowed with a scalar.
+ return self._comparator.has(
+ getattr(self.target_class, self.value_attr) != obj
+ )
+
+
+class ColumnAssociationProxyInstance(
+ ColumnOperators, AssociationProxyInstance
+):
+ """an :class:`.AssociationProxyInstance` that has a database column as a
+ target.
+ """
+
+ _target_is_object = False
+ _is_canonical = True
+
+ def __eq__(self, other):
+ # special case "is None" to check for no related row as well
+ expr = self._criterion_exists(
+ self.remote_attr.operate(operator.eq, other)
+ )
+ if other is None:
+ return or_(expr, self._comparator == None)
+ else:
+ return expr
+
+ def operate(self, op, *other, **kwargs):
+ return self._criterion_exists(
+ self.remote_attr.operate(op, *other, **kwargs)
+ )
+
+
+class _lazy_collection(object):
+ def __init__(self, obj, target):
+ self.parent = obj
+ self.target = target
+
+ def __call__(self):
+ return getattr(self.parent, self.target)
+
+ def __getstate__(self):
+ return {"obj": self.parent, "target": self.target}
+
+ def __setstate__(self, state):
+ self.parent = state["obj"]
+ self.target = state["target"]
+
+
+class _AssociationCollection(object):
+ def __init__(self, lazy_collection, creator, getter, setter, parent):
+ """Constructs an _AssociationCollection.
+
+ This will always be a subclass of either _AssociationList,
+ _AssociationSet, or _AssociationDict.
+
+ lazy_collection
+ A callable returning a list-based collection of entities (usually an
+ object attribute managed by a SQLAlchemy relationship())
+
+ creator
+ A function that creates new target entities. Given one parameter:
+ value. This assertion is assumed::
+
+ obj = creator(somevalue)
+ assert getter(obj) == somevalue
+
+ getter
+ A function. Given an associated object, return the 'value'.
+
+ setter
+ A function. Given an associated object and a value, store that
+ value on the object.
+
+ """
+ self.lazy_collection = lazy_collection
+ self.creator = creator
+ self.getter = getter
+ self.setter = setter
+ self.parent = parent
+
+ col = property(lambda self: self.lazy_collection())
+
+ def __len__(self):
+ return len(self.col)
+
+ def __bool__(self):
+ return bool(self.col)
+
+ __nonzero__ = __bool__
+
+ def __getstate__(self):
+ return {"parent": self.parent, "lazy_collection": self.lazy_collection}
+
+ def __setstate__(self, state):
+ self.parent = state["parent"]
+ self.lazy_collection = state["lazy_collection"]
+ self.parent._inflate(self)
+
+ def _bulk_replace(self, assoc_proxy, values):
+ self.clear()
+ assoc_proxy._set(self, values)
+
+
+class _AssociationList(_AssociationCollection):
+ """Generic, converting, list-to-list proxy."""
+
+ def _create(self, value):
+ return self.creator(value)
+
+ def _get(self, object_):
+ return self.getter(object_)
+
+ def _set(self, object_, value):
+ return self.setter(object_, value)
+
+ def __getitem__(self, index):
+ if not isinstance(index, slice):
+ return self._get(self.col[index])
+ else:
+ return [self._get(member) for member in self.col[index]]
+
+ def __setitem__(self, index, value):
+ if not isinstance(index, slice):
+ self._set(self.col[index], value)
+ else:
+ if index.stop is None:
+ stop = len(self)
+ elif index.stop < 0:
+ stop = len(self) + index.stop
+ else:
+ stop = index.stop
+ step = index.step or 1
+
+ start = index.start or 0
+            rng = list(range(start, stop, step))
+ if step == 1:
+ for i in rng:
+ del self[start]
+ i = start
+ for item in value:
+ self.insert(i, item)
+ i += 1
+ else:
+ if len(value) != len(rng):
+ raise ValueError(
+ "attempt to assign sequence of size %s to "
+ "extended slice of size %s" % (len(value), len(rng))
+ )
+ for i, item in zip(rng, value):
+ self._set(self.col[i], item)
+
+ def __delitem__(self, index):
+ del self.col[index]
+
+ def __contains__(self, value):
+ for member in self.col:
+ # testlib.pragma exempt:__eq__
+ if self._get(member) == value:
+ return True
+ return False
+
+ def __getslice__(self, start, end):
+ return [self._get(member) for member in self.col[start:end]]
+
+ def __setslice__(self, start, end, values):
+ members = [self._create(v) for v in values]
+ self.col[start:end] = members
+
+ def __delslice__(self, start, end):
+ del self.col[start:end]
+
+ def __iter__(self):
+ """Iterate over proxied values.
+
+ For the actual domain objects, iterate over .col instead or
+ just use the underlying collection directly from its property
+ on the parent.
+ """
+
+ for member in self.col:
+ yield self._get(member)
+ return
+
+ def append(self, value):
+ col = self.col
+ item = self._create(value)
+ col.append(item)
+
+ def count(self, value):
+ return sum(
+ [
+ 1
+ for _ in util.itertools_filter(
+ lambda v: v == value, iter(self)
+ )
+ ]
+ )
+
+ def extend(self, values):
+ for v in values:
+ self.append(v)
+
+ def insert(self, index, value):
+ self.col[index:index] = [self._create(value)]
+
+ def pop(self, index=-1):
+ return self.getter(self.col.pop(index))
+
+ def remove(self, value):
+ for i, val in enumerate(self):
+ if val == value:
+ del self.col[i]
+ return
+ raise ValueError("value not in list")
+
+ def reverse(self):
+ """Not supported, use reversed(mylist)"""
+
+ raise NotImplementedError
+
+ def sort(self):
+ """Not supported, use sorted(mylist)"""
+
+ raise NotImplementedError
+
+ def clear(self):
+ del self.col[0 : len(self.col)]
+
+ def __eq__(self, other):
+ return list(self) == other
+
+ def __ne__(self, other):
+ return list(self) != other
+
+ def __lt__(self, other):
+ return list(self) < other
+
+ def __le__(self, other):
+ return list(self) <= other
+
+ def __gt__(self, other):
+ return list(self) > other
+
+ def __ge__(self, other):
+ return list(self) >= other
+
+ def __cmp__(self, other):
+ return util.cmp(list(self), other)
+
+ def __add__(self, iterable):
+ try:
+ other = list(iterable)
+ except TypeError:
+ return NotImplemented
+ return list(self) + other
+
+ def __radd__(self, iterable):
+ try:
+ other = list(iterable)
+ except TypeError:
+ return NotImplemented
+ return other + list(self)
+
+ def __mul__(self, n):
+ if not isinstance(n, int):
+ return NotImplemented
+ return list(self) * n
+
+ __rmul__ = __mul__
+
+ def __iadd__(self, iterable):
+ self.extend(iterable)
+ return self
+
+ def __imul__(self, n):
+ # unlike a regular list *=, proxied __imul__ will generate unique
+ # backing objects for each copy. *= on proxied lists is a bit of
+ # a stretch anyhow, and this interpretation of the __imul__ contract
+ # is more plausibly useful than copying the backing objects.
+ if not isinstance(n, int):
+ return NotImplemented
+ if n == 0:
+ self.clear()
+ elif n > 1:
+ self.extend(list(self) * (n - 1))
+ return self
+
+ def index(self, item, *args):
+ return list(self).index(item, *args)
+
+ def copy(self):
+ return list(self)
+
+ def __repr__(self):
+ return repr(list(self))
+
+ def __hash__(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
+ func.__doc__ = getattr(list, func_name).__doc__
+ del func_name, func
+
+
+_NotProvided = util.symbol("_NotProvided")
+
+
+class _AssociationDict(_AssociationCollection):
+ """Generic, converting, dict-to-dict proxy."""
+
+ def _create(self, key, value):
+ return self.creator(key, value)
+
+ def _get(self, object_):
+ return self.getter(object_)
+
+ def _set(self, object_, key, value):
+ return self.setter(object_, key, value)
+
+ def __getitem__(self, key):
+ return self._get(self.col[key])
+
+ def __setitem__(self, key, value):
+ if key in self.col:
+ self._set(self.col[key], key, value)
+ else:
+ self.col[key] = self._create(key, value)
+
+ def __delitem__(self, key):
+ del self.col[key]
+
+ def __contains__(self, key):
+ # testlib.pragma exempt:__hash__
+ return key in self.col
+
+ def has_key(self, key):
+ # testlib.pragma exempt:__hash__
+ return key in self.col
+
+ def __iter__(self):
+ return iter(self.col.keys())
+
+ def clear(self):
+ self.col.clear()
+
+ def __eq__(self, other):
+ return dict(self) == other
+
+ def __ne__(self, other):
+ return dict(self) != other
+
+ def __lt__(self, other):
+ return dict(self) < other
+
+ def __le__(self, other):
+ return dict(self) <= other
+
+ def __gt__(self, other):
+ return dict(self) > other
+
+ def __ge__(self, other):
+ return dict(self) >= other
+
+ def __cmp__(self, other):
+ return util.cmp(dict(self), other)
+
+ def __repr__(self):
+ return repr(dict(self.items()))
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def setdefault(self, key, default=None):
+ if key not in self.col:
+ self.col[key] = self._create(key, default)
+ return default
+ else:
+ return self[key]
+
+ def keys(self):
+ return self.col.keys()
+
+ if util.py2k:
+
+ def iteritems(self):
+ return ((key, self._get(self.col[key])) for key in self.col)
+
+ def itervalues(self):
+ return (self._get(self.col[key]) for key in self.col)
+
+ def iterkeys(self):
+ return self.col.iterkeys()
+
+ def values(self):
+ return [self._get(member) for member in self.col.values()]
+
+ def items(self):
+ return [(k, self._get(self.col[k])) for k in self]
+
+ else:
+
+ def items(self):
+ return ((key, self._get(self.col[key])) for key in self.col)
+
+ def values(self):
+ return (self._get(self.col[key]) for key in self.col)
+
+ def pop(self, key, default=_NotProvided):
+ if default is _NotProvided:
+ member = self.col.pop(key)
+ else:
+ member = self.col.pop(key, default)
+ return self._get(member)
+
+ def popitem(self):
+ item = self.col.popitem()
+ return (item[0], self._get(item[1]))
+
+ def update(self, *a, **kw):
+ if len(a) > 1:
+ raise TypeError(
+ "update expected at most 1 arguments, got %i" % len(a)
+ )
+ elif len(a) == 1:
+ seq_or_map = a[0]
+ # discern dict from sequence - took the advice from
+ # https://www.voidspace.org.uk/python/articles/duck_typing.shtml
+ # still not perfect :(
+ if hasattr(seq_or_map, "keys"):
+ for item in seq_or_map:
+ self[item] = seq_or_map[item]
+ else:
+ try:
+ for k, v in seq_or_map:
+ self[k] = v
+ except ValueError as err:
+ util.raise_(
+ ValueError(
+ "dictionary update sequence "
+ "requires 2-element tuples"
+ ),
+ replace_context=err,
+ )
+
+        for key, value in kw.items():
+ self[key] = value
+
+ def _bulk_replace(self, assoc_proxy, values):
+ existing = set(self)
+ constants = existing.intersection(values or ())
+ additions = set(values or ()).difference(constants)
+ removals = existing.difference(constants)
+
+        for key, member in (values or {}).items():
+ if key in additions:
+ self[key] = member
+ elif key in constants:
+ self[key] = member
+
+ for key in removals:
+ del self[key]
+
+ def copy(self):
+ return dict(self.items())
+
+ def __hash__(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(dict, func_name)
+ ):
+ func.__doc__ = getattr(dict, func_name).__doc__
+ del func_name, func
+
+
+class _AssociationSet(_AssociationCollection):
+ """Generic, converting, set-to-set proxy."""
+
+ def _create(self, value):
+ return self.creator(value)
+
+ def _get(self, object_):
+ return self.getter(object_)
+
+ def __len__(self):
+ return len(self.col)
+
+ def __bool__(self):
+ if self.col:
+ return True
+ else:
+ return False
+
+ __nonzero__ = __bool__
+
+ def __contains__(self, value):
+ for member in self.col:
+ # testlib.pragma exempt:__eq__
+ if self._get(member) == value:
+ return True
+ return False
+
+ def __iter__(self):
+ """Iterate over proxied values.
+
+ For the actual domain objects, iterate over .col instead or just use
+ the underlying collection directly from its property on the parent.
+
+ """
+ for member in self.col:
+ yield self._get(member)
+ return
+
+ def add(self, value):
+ if value not in self:
+ self.col.add(self._create(value))
+
+ # for discard and remove, choosing a more expensive check strategy rather
+ # than call self.creator()
+ def discard(self, value):
+ for member in self.col:
+ if self._get(member) == value:
+ self.col.discard(member)
+ break
+
+ def remove(self, value):
+ for member in self.col:
+ if self._get(member) == value:
+ self.col.discard(member)
+ return
+ raise KeyError(value)
+
+ def pop(self):
+ if not self.col:
+ raise KeyError("pop from an empty set")
+ member = self.col.pop()
+ return self._get(member)
+
+ def update(self, other):
+ for value in other:
+ self.add(value)
+
+ def _bulk_replace(self, assoc_proxy, values):
+ existing = set(self)
+ constants = existing.intersection(values or ())
+ additions = set(values or ()).difference(constants)
+ removals = existing.difference(constants)
+
+ appender = self.add
+ remover = self.remove
+
+ for member in values or ():
+ if member in additions:
+ appender(member)
+ elif member in constants:
+ appender(member)
+
+ for member in removals:
+ remover(member)
+
+ def __ior__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ for value in other:
+ self.add(value)
+ return self
+
+ def _set(self):
+ return set(iter(self))
+
+ def union(self, other):
+ return set(self).union(other)
+
+ __or__ = union
+
+ def difference(self, other):
+ return set(self).difference(other)
+
+ __sub__ = difference
+
+ def difference_update(self, other):
+ for value in other:
+ self.discard(value)
+
+ def __isub__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ for value in other:
+ self.discard(value)
+ return self
+
+ def intersection(self, other):
+ return set(self).intersection(other)
+
+ __and__ = intersection
+
+ def intersection_update(self, other):
+ want, have = self.intersection(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+
+ def __iand__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ want, have = self.intersection(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+ return self
+
+ def symmetric_difference(self, other):
+ return set(self).symmetric_difference(other)
+
+ __xor__ = symmetric_difference
+
+ def symmetric_difference_update(self, other):
+ want, have = self.symmetric_difference(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+
+ def __ixor__(self, other):
+ if not collections._set_binops_check_strict(self, other):
+ return NotImplemented
+ want, have = self.symmetric_difference(other), set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+ return self
+
+ def issubset(self, other):
+ return set(self).issubset(other)
+
+ def issuperset(self, other):
+ return set(self).issuperset(other)
+
+ def clear(self):
+ self.col.clear()
+
+ def copy(self):
+ return set(self)
+
+ def __eq__(self, other):
+ return set(self) == other
+
+ def __ne__(self, other):
+ return set(self) != other
+
+ def __lt__(self, other):
+ return set(self) < other
+
+ def __le__(self, other):
+ return set(self) <= other
+
+ def __gt__(self, other):
+ return set(self) > other
+
+ def __ge__(self, other):
+ return set(self) >= other
+
+ def __repr__(self):
+ return repr(set(self))
+
+ def __hash__(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(set, func_name)
+ ):
+ func.__doc__ = getattr(set, func_name).__doc__
+ del func_name, func
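
For orientation, typical usage of the association proxy module added above
looks roughly like the following sketch; the ``User``/``Keyword`` mapping is
hypothetical and not part of this commit::

    from sqlalchemy import Column, ForeignKey, Integer, String, Table
    from sqlalchemy.ext.associationproxy import association_proxy
    from sqlalchemy.orm import declarative_base, relationship

    Base = declarative_base()

    user_keyword = Table(
        "user_keyword",
        Base.metadata,
        Column("user_id", ForeignKey("user.id"), primary_key=True),
        Column("keyword_id", ForeignKey("keyword.id"), primary_key=True),
    )

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        kws = relationship("Keyword", secondary=user_keyword)

        # proxy the 'keyword' string attribute of each Keyword in 'kws'
        keywords = association_proxy("kws", "keyword")

    class Keyword(Base):
        __tablename__ = "keyword"
        id = Column(Integer, primary_key=True)
        keyword = Column(String(64))

        def __init__(self, keyword):
            self.keyword = keyword

    u = User()
    u.keywords.append("orm")      # default creator invokes Keyword("orm")
    assert u.keywords == ["orm"]  # _AssociationList compares like a list
    assert u.kws[0].keyword == "orm"

    # at query time the proxy generates EXISTS criteria, e.g.
    # select(User).where(User.keywords.any(Keyword.keyword == "orm"))
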
diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py
new file mode 100644
index 0000000..15b2cb0
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/__init__.py
@@ -0,0 +1,22 @@
+# ext/asyncio/__init__.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .engine import async_engine_from_config
+from .engine import AsyncConnection
+from .engine import AsyncEngine
+from .engine import AsyncTransaction
+from .engine import create_async_engine
+from .events import AsyncConnectionEvents
+from .events import AsyncSessionEvents
+from .result import AsyncMappingResult
+from .result import AsyncResult
+from .result import AsyncScalarResult
+from .scoping import async_scoped_session
+from .session import async_object_session
+from .session import async_session
+from .session import AsyncSession
+from .session import AsyncSessionTransaction
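
A minimal end-to-end sketch of the asyncio facade exported by this package;
the asyncpg URL is a placeholder and the example itself is not part of this
commit::

    import asyncio

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import create_async_engine

    async def main():
        # placeholder URL; any asyncio-compatible dialect may be used
        engine = create_async_engine(
            "postgresql+asyncpg://user:pass@host/dbname"
        )
        async with engine.connect() as conn:
            result = await conn.execute(text("SELECT 'hello'"))
            print(result.scalar())
        await engine.dispose()

    asyncio.run(main())
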
diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py
new file mode 100644
index 0000000..3f77f55
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/base.py
@@ -0,0 +1,89 @@
+import abc
+import functools
+import weakref
+
+from . import exc as async_exc
+
+
+class ReversibleProxy:
+    # weakref.ref(sync proxied object) -> weakref.ref(async proxy object)
+ _proxy_objects = {}
+ __slots__ = ("__weakref__",)
+
+ def _assign_proxied(self, target):
+ if target is not None:
+ target_ref = weakref.ref(target, ReversibleProxy._target_gced)
+ proxy_ref = weakref.ref(
+ self,
+ functools.partial(ReversibleProxy._target_gced, target_ref),
+ )
+ ReversibleProxy._proxy_objects[target_ref] = proxy_ref
+
+ return target
+
+ @classmethod
+ def _target_gced(cls, ref, proxy_ref=None):
+ cls._proxy_objects.pop(ref, None)
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ raise NotImplementedError()
+
+ @classmethod
+ def _retrieve_proxy_for_target(cls, target, regenerate=True):
+ try:
+ proxy_ref = cls._proxy_objects[weakref.ref(target)]
+ except KeyError:
+ pass
+ else:
+ proxy = proxy_ref()
+ if proxy is not None:
+ return proxy
+
+ if regenerate:
+ return cls._regenerate_proxy_for_target(target)
+ else:
+ return None
+
+
+class StartableContext(abc.ABC):
+ __slots__ = ()
+
+ @abc.abstractmethod
+ async def start(self, is_ctxmanager=False):
+ pass
+
+ def __await__(self):
+ return self.start().__await__()
+
+ async def __aenter__(self):
+ return await self.start(is_ctxmanager=True)
+
+ @abc.abstractmethod
+ async def __aexit__(self, type_, value, traceback):
+ pass
+
+ def _raise_for_not_started(self):
+ raise async_exc.AsyncContextNotStarted(
+ "%s context has not been started and object has not been awaited."
+ % (self.__class__.__name__)
+ )
+
+
+class ProxyComparable(ReversibleProxy):
+ __slots__ = ()
+
+ def __hash__(self):
+ return id(self)
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, self.__class__)
+ and self._proxied == other._proxied
+ )
+
+ def __ne__(self, other):
+ return (
+ not isinstance(other, self.__class__)
+ or self._proxied != other._proxied
+ )
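
``StartableContext`` above is what lets the async proxy objects be used
either as plain awaitables or as ``async with`` blocks; a toy illustration
(the ``Demo`` class is hypothetical and exercises this internal API purely
for demonstration)::

    import asyncio

    from sqlalchemy.ext.asyncio.base import StartableContext

    class Demo(StartableContext):
        async def start(self, is_ctxmanager=False):
            # both entry paths funnel through start()
            print("started (ctxmanager=%s)" % is_ctxmanager)
            return self

        async def __aexit__(self, type_, value, traceback):
            print("exited")

    async def main():
        await Demo()        # awaitable form: __await__() -> start()
        async with Demo():  # context form: __aenter__() -> start(True)
            pass            # ...then __aexit__() on the way out

    asyncio.run(main())
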
diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py
new file mode 100644
index 0000000..4fbe4f7
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/engine.py
@@ -0,0 +1,828 @@
+# ext/asyncio/engine.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+import asyncio
+
+from . import exc as async_exc
+from .base import ProxyComparable
+from .base import StartableContext
+from .result import _ensure_sync_result
+from .result import AsyncResult
+from ... import exc
+from ... import inspection
+from ... import util
+from ...engine import create_engine as _create_engine
+from ...engine.base import NestedTransaction
+from ...future import Connection
+from ...future import Engine
+from ...util.concurrency import greenlet_spawn
+
+
+def create_async_engine(*arg, **kw):
+ """Create a new async engine instance.
+
+ Arguments passed to :func:`_asyncio.create_async_engine` are mostly
+ identical to those passed to the :func:`_sa.create_engine` function.
+ The specified dialect must be an asyncio-compatible dialect
+ such as :ref:`dialect-postgresql-asyncpg`.
+
+ .. versionadded:: 1.4
+
+ """
+
+ if kw.get("server_side_cursors", False):
+ raise async_exc.AsyncMethodRequired(
+ "Can't set server_side_cursors for async engine globally; "
+ "use the connection.stream() method for an async "
+ "streaming result set"
+ )
+ kw["future"] = True
+ sync_engine = _create_engine(*arg, **kw)
+ return AsyncEngine(sync_engine)
+
+
+def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs):
+ """Create a new AsyncEngine instance using a configuration dictionary.
+
+ This function is analogous to the :func:`_sa.engine_from_config` function
+ in SQLAlchemy Core, except that the requested dialect must be an
+ asyncio-compatible dialect such as :ref:`dialect-postgresql-asyncpg`.
+ The argument signature of the function is identical to that
+ of :func:`_sa.engine_from_config`.
+
+ .. versionadded:: 1.4.29
+
+ """
+ options = {
+ key[len(prefix) :]: value
+ for key, value in configuration.items()
+ if key.startswith(prefix)
+ }
+ options["_coerce_config"] = True
+ options.update(kwargs)
+ url = options.pop("url")
+ return create_async_engine(url, **options)
+
+
+class AsyncConnectable:
+ __slots__ = "_slots_dispatch", "__weakref__"
+
+
+@util.create_proxy_methods(
+ Connection,
+ ":class:`_future.Connection`",
+ ":class:`_asyncio.AsyncConnection`",
+ classmethods=[],
+ methods=[],
+ attributes=[
+ "closed",
+ "invalidated",
+ "dialect",
+ "default_isolation_level",
+ ],
+)
+class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
+ """An asyncio proxy for a :class:`_engine.Connection`.
+
+ :class:`_asyncio.AsyncConnection` is acquired using the
+ :meth:`_asyncio.AsyncEngine.connect`
+ method of :class:`_asyncio.AsyncEngine`::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+ async with engine.connect() as conn:
+ result = await conn.execute(select(table))
+
+ .. versionadded:: 1.4
+
+ """ # noqa
+
+ # AsyncConnection is a thin proxy; no state should be added here
+ # that is not retrievable from the "sync" engine / connection, e.g.
+ # current transaction, info, etc. It should be possible to
+ # create a new AsyncConnection that matches this one given only the
+ # "sync" elements.
+ __slots__ = (
+ "engine",
+ "sync_engine",
+ "sync_connection",
+ )
+
+ def __init__(self, async_engine, sync_connection=None):
+ self.engine = async_engine
+ self.sync_engine = async_engine.sync_engine
+ self.sync_connection = self._assign_proxied(sync_connection)
+
+ sync_connection: Connection
+ """Reference to the sync-style :class:`_engine.Connection` this
+ :class:`_asyncio.AsyncConnection` proxies requests towards.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+ """
+
+ sync_engine: Engine
+ """Reference to the sync-style :class:`_engine.Engine` this
+ :class:`_asyncio.AsyncConnection` is associated with via its underlying
+ :class:`_engine.Connection`.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+ """
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncConnection(
+ AsyncEngine._retrieve_proxy_for_target(target.engine), target
+ )
+
+ async def start(self, is_ctxmanager=False):
+ """Start this :class:`_asyncio.AsyncConnection` object's context
+ outside of using a Python ``with:`` block.
+
+ """
+ if self.sync_connection:
+ raise exc.InvalidRequestError("connection is already started")
+ self.sync_connection = self._assign_proxied(
+            await greenlet_spawn(self.sync_engine.connect)
+ )
+ return self
+
+ @property
+ def connection(self):
+ """Not implemented for async; call
+ :meth:`_asyncio.AsyncConnection.get_raw_connection`.
+ """
+ raise exc.InvalidRequestError(
+ "AsyncConnection.connection accessor is not implemented as the "
+ "attribute may need to reconnect on an invalidated connection. "
+ "Use the get_raw_connection() method."
+ )
+
+ async def get_raw_connection(self):
+ """Return the pooled DBAPI-level connection in use by this
+ :class:`_asyncio.AsyncConnection`.
+
+ This is a SQLAlchemy connection-pool proxied connection
+ which then has the attribute
+ :attr:`_pool._ConnectionFairy.driver_connection` that refers to the
+ actual driver connection. Its
+ :attr:`_pool._ConnectionFairy.dbapi_connection` refers instead
+ to an :class:`_engine.AdaptedConnection` instance that
+ adapts the driver connection to the DBAPI protocol.
+
+ """
+ conn = self._sync_connection()
+
+ return await greenlet_spawn(getattr, conn, "connection")
+
+ @property
+ def _proxied(self):
+ return self.sync_connection
+
+ @property
+ def info(self):
+ """Return the :attr:`_engine.Connection.info` dictionary of the
+ underlying :class:`_engine.Connection`.
+
+ This dictionary is freely writable for user-defined state to be
+ associated with the database connection.
+
+ This attribute is only available if the :class:`.AsyncConnection` is
+ currently connected. If the :attr:`.AsyncConnection.closed` attribute
+ is ``True``, then accessing this attribute will raise
+ :class:`.ResourceClosedError`.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ return self.sync_connection.info
+
+ def _sync_connection(self):
+ if not self.sync_connection:
+ self._raise_for_not_started()
+ return self.sync_connection
+
+ def begin(self):
+ """Begin a transaction prior to autobegin occurring."""
+ self._sync_connection()
+ return AsyncTransaction(self)
+
+ def begin_nested(self):
+ """Begin a nested transaction and return a transaction handle."""
+ self._sync_connection()
+ return AsyncTransaction(self, nested=True)
+
+ async def invalidate(self, exception=None):
+ """Invalidate the underlying DBAPI connection associated with
+ this :class:`_engine.Connection`.
+
+ See the method :meth:`_engine.Connection.invalidate` for full
+ detail on this method.
+
+ """
+
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.invalidate, exception=exception)
+
+ async def get_isolation_level(self):
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.get_isolation_level)
+
+ async def set_isolation_level(self):
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.get_isolation_level)
+
+ def in_transaction(self):
+ """Return True if a transaction is in progress.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+
+ conn = self._sync_connection()
+
+ return conn.in_transaction()
+
+ def in_nested_transaction(self):
+ """Return True if a transaction is in progress.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ conn = self._sync_connection()
+
+ return conn.in_nested_transaction()
+
+ def get_transaction(self):
+ """Return an :class:`.AsyncTransaction` representing the current
+ transaction, if any.
+
+ This makes use of the underlying synchronous connection's
+ :meth:`_engine.Connection.get_transaction` method to get the current
+ :class:`_engine.Transaction`, which is then proxied in a new
+ :class:`.AsyncTransaction` object.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ conn = self._sync_connection()
+
+ trans = conn.get_transaction()
+ if trans is not None:
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ def get_nested_transaction(self):
+ """Return an :class:`.AsyncTransaction` representing the current
+ nested (savepoint) transaction, if any.
+
+ This makes use of the underlying synchronous connection's
+ :meth:`_engine.Connection.get_nested_transaction` method to get the
+ current :class:`_engine.Transaction`, which is then proxied in a new
+ :class:`.AsyncTransaction` object.
+
+ .. versionadded:: 1.4.0b2
+
+ """
+ conn = self._sync_connection()
+
+ trans = conn.get_nested_transaction()
+ if trans is not None:
+ return AsyncTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ async def execution_options(self, **opt):
+ r"""Set non-SQL options for the connection which take effect
+ during execution.
+
+ This returns this :class:`_asyncio.AsyncConnection` object with
+ the new options added.
+
+ See :meth:`_future.Connection.execution_options` for full details
+ on this method.
+
+ """
+
+ conn = self._sync_connection()
+ c2 = await greenlet_spawn(conn.execution_options, **opt)
+ assert c2 is conn
+ return self
+
+ async def commit(self):
+ """Commit the transaction that is currently in progress.
+
+ This method commits the current transaction if one has been started.
+ If no transaction was started, the method has no effect, assuming
+ the connection is in a non-invalidated state.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.commit)
+
+ async def rollback(self):
+ """Roll back the transaction that is currently in progress.
+
+ This method rolls back the current transaction if one has been started.
+ If no transaction was started, the method has no effect. If a
+ transaction was started and the connection is in an invalidated state,
+ the transaction is cleared using this method.
+
+ A transaction is begun on a :class:`_future.Connection` automatically
+ whenever a statement is first executed, or when the
+ :meth:`_future.Connection.begin` method is called.
+
+
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.rollback)
+
+ async def close(self):
+ """Close this :class:`_asyncio.AsyncConnection`.
+
+ This has the effect of also rolling back the transaction if one
+ is in place.
+
+ """
+ conn = self._sync_connection()
+ await greenlet_spawn(conn.close)
+
+ async def exec_driver_sql(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+        r"""Execute a driver-level SQL string and return a buffered
+        :class:`_engine.Result`.
+
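+        E.g., a minimal sketch; the parameter format shown assumes a
+        DBAPI that uses the ``qmark`` paramstyle, which varies per
+        driver::
+
+            result = await conn.exec_driver_sql(
+                "SELECT * FROM some_table WHERE id = ?", (5,)
+            )
+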
+ """
+
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn.exec_driver_sql,
+ statement,
+ parameters,
+ execution_options,
+ _require_await=True,
+ )
+
+ return await _ensure_sync_result(result, self.exec_driver_sql)
+
+ async def stream(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+        """Execute a statement and return a streaming
+        :class:`_asyncio.AsyncResult` object.
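+
+        E.g., a minimal sketch (assumes ``select`` and a ``users_table``
+        construct are available in the enclosing scope)::
+
+            result = await conn.stream(select(users_table))
+            async for row in result:
+                print(row)
+
+        """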
+
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn._execute_20,
+ statement,
+ parameters,
+ util.EMPTY_DICT.merge_with(
+ execution_options, {"stream_results": True}
+ ),
+ _require_await=True,
+ )
+ if not result.context._is_server_side:
+ # TODO: real exception here
+ assert False, "server side result expected"
+ return AsyncResult(result)
+
+ async def execute(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+        r"""Execute a SQL statement construct and return a buffered
+        :class:`_engine.Result`.
+
+        :param statement: The statement to be executed.  This is always
+ an object that is in both the :class:`_expression.ClauseElement` and
+ :class:`_expression.Executable` hierarchies, including:
+
+ * :class:`_expression.Select`
+ * :class:`_expression.Insert`, :class:`_expression.Update`,
+ :class:`_expression.Delete`
+ * :class:`_expression.TextClause` and
+ :class:`_expression.TextualSelect`
+ * :class:`_schema.DDL` and objects which inherit from
+ :class:`_schema.DDLElement`
+
+ :param parameters: parameters which will be bound into the statement.
+ This may be either a dictionary of parameter names to values,
+ or a mutable sequence (e.g. a list) of dictionaries. When a
+ list of dictionaries is passed, the underlying statement execution
+ will make use of the DBAPI ``cursor.executemany()`` method.
+ When a single dictionary is passed, the DBAPI ``cursor.execute()``
+ method will be used.
+
+ :param execution_options: optional dictionary of execution options,
+ which will be associated with the statement execution. This
+ dictionary can provide a subset of the options that are accepted
+ by :meth:`_future.Connection.execution_options`.
+
+ :return: a :class:`_engine.Result` object.
+
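+        E.g., a minimal sketch (assumes ``text`` is imported from
+        ``sqlalchemy``)::
+
+            result = await conn.execute(text("select * from some_table"))
+            rows = result.fetchall()
+
+        The returned :class:`_engine.Result` is buffered, so its fetch
+        methods are invoked without ``await``.
+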
+ """
+ conn = self._sync_connection()
+
+ result = await greenlet_spawn(
+ conn._execute_20,
+ statement,
+ parameters,
+ execution_options,
+ _require_await=True,
+ )
+ return await _ensure_sync_result(result, self.execute)
+
+ async def scalar(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+        r"""Execute a SQL statement construct and return a scalar value.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.Result.scalar` method after invoking the
+ :meth:`_future.Connection.execute` method. Parameters are equivalent.
+
+ :return: a scalar Python value representing the first column of the
+ first row returned.
+
+ """
+ result = await self.execute(statement, parameters, execution_options)
+ return result.scalar()
+
+ async def scalars(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+        r"""Execute a SQL statement construct and return scalar values.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.Result.scalars` method after invoking the
+ :meth:`_future.Connection.execute` method. Parameters are equivalent.
+
+ :return: a :class:`_engine.ScalarResult` object.
+
+ .. versionadded:: 1.4.24
+
+ """
+ result = await self.execute(statement, parameters, execution_options)
+ return result.scalars()
+
+ async def stream_scalars(
+ self,
+ statement,
+ parameters=None,
+ execution_options=util.EMPTY_DICT,
+ ):
+        r"""Execute a SQL statement and return a streaming scalar result
+ object.
+
+ This method is shorthand for invoking the
+ :meth:`_engine.AsyncResult.scalars` method after invoking the
+ :meth:`_future.Connection.stream` method. Parameters are equivalent.
+
+ :return: an :class:`_asyncio.AsyncScalarResult` object.
+
+ .. versionadded:: 1.4.24
+
+ """
+ result = await self.stream(statement, parameters, execution_options)
+ return result.scalars()
+
+ async def run_sync(self, fn, *arg, **kw):
+ """Invoke the given sync callable passing self as the first argument.
+
+ This method maintains the asyncio event loop all the way through
+ to the database connection by running the given callable in a
+ specially instrumented greenlet.
+
+ E.g.::
+
+ with async_engine.begin() as conn:
+ await conn.run_sync(metadata.create_all)
+
+ .. note::
+
+ The provided callable is invoked inline within the asyncio event
+ loop, and will block on traditional IO calls. IO within this
+ callable should only call into SQLAlchemy's asyncio database
+ APIs which will be properly adapted to the greenlet context.
+
+ .. seealso::
+
+ :ref:`session_run_sync`
+ """
+
+ conn = self._sync_connection()
+
+ return await greenlet_spawn(fn, conn, *arg, **kw)
+
+ def __await__(self):
+ return self.start().__await__()
+
+ async def __aexit__(self, type_, value, traceback):
+ await asyncio.shield(self.close())
+
+
+@util.create_proxy_methods(
+ Engine,
+ ":class:`_future.Engine`",
+ ":class:`_asyncio.AsyncEngine`",
+ classmethods=[],
+ methods=[
+ "clear_compiled_cache",
+ "update_execution_options",
+ "get_execution_options",
+ ],
+ attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
+)
+class AsyncEngine(ProxyComparable, AsyncConnectable):
+ """An asyncio proxy for a :class:`_engine.Engine`.
+
+ :class:`_asyncio.AsyncEngine` is acquired using the
+ :func:`_asyncio.create_async_engine` function::
+
+ from sqlalchemy.ext.asyncio import create_async_engine
+ engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+ .. versionadded:: 1.4
+
+ """ # noqa
+
+ # AsyncEngine is a thin proxy; no state should be added here
+ # that is not retrievable from the "sync" engine / connection, e.g.
+ # current transaction, info, etc. It should be possible to
+ # create a new AsyncEngine that matches this one given only the
+ # "sync" elements.
+ __slots__ = ("sync_engine", "_proxied")
+
+ _connection_cls = AsyncConnection
+
+ _option_cls: type
+
+ class _trans_ctx(StartableContext):
+ def __init__(self, conn):
+ self.conn = conn
+
+ async def start(self, is_ctxmanager=False):
+ await self.conn.start(is_ctxmanager=is_ctxmanager)
+ self.transaction = self.conn.begin()
+ await self.transaction.__aenter__()
+
+ return self.conn
+
+ async def __aexit__(self, type_, value, traceback):
+ async def go():
+ await self.transaction.__aexit__(type_, value, traceback)
+ await self.conn.close()
+
+ await asyncio.shield(go())
+
+ def __init__(self, sync_engine):
+ if not sync_engine.dialect.is_async:
+ raise exc.InvalidRequestError(
+ "The asyncio extension requires an async driver to be used. "
+ f"The loaded {sync_engine.dialect.driver!r} is not async."
+ )
+ self.sync_engine = self._proxied = self._assign_proxied(sync_engine)
+
+ sync_engine: Engine
+ """Reference to the sync-style :class:`_engine.Engine` this
+ :class:`_asyncio.AsyncEngine` proxies requests towards.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+ """
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ return AsyncEngine(target)
+
+ def begin(self):
+ """Return a context manager which when entered will deliver an
+ :class:`_asyncio.AsyncConnection` with an
+ :class:`_asyncio.AsyncTransaction` established.
+
+ E.g.::
+
+ async with async_engine.begin() as conn:
+ await conn.execute(
+ text("insert into table (x, y, z) values (1, 2, 3)")
+ )
+ await conn.execute(text("my_special_procedure(5)"))
+
+
+ """
+ conn = self.connect()
+ return self._trans_ctx(conn)
+
+ def connect(self):
+ """Return an :class:`_asyncio.AsyncConnection` object.
+
+ The :class:`_asyncio.AsyncConnection` will procure a database
+ connection from the underlying connection pool when it is entered
+ as an async context manager::
+
+ async with async_engine.connect() as conn:
+ result = await conn.execute(select(user_table))
+
+ The :class:`_asyncio.AsyncConnection` may also be started outside of a
+ context manager by invoking its :meth:`_asyncio.AsyncConnection.start`
+ method.
+
+ """
+
+ return self._connection_cls(self)
+
+ async def raw_connection(self):
+ """Return a "raw" DBAPI connection from the connection pool.
+
+ .. seealso::
+
+ :ref:`dbapi_connections`
+
+ """
+ return await greenlet_spawn(self.sync_engine.raw_connection)
+
+ def execution_options(self, **opt):
+ """Return a new :class:`_asyncio.AsyncEngine` that will provide
+ :class:`_asyncio.AsyncConnection` objects with the given execution
+ options.
+
+ Proxied from :meth:`_future.Engine.execution_options`. See that
+ method for details.
+
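+        E.g., a sketch; ``isolation_level`` is one execution option
+        accepted by most included dialects::
+
+            autocommit_engine = async_engine.execution_options(
+                isolation_level="AUTOCOMMIT"
+            )
+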
+ """
+
+ return AsyncEngine(self.sync_engine.execution_options(**opt))
+
+ async def dispose(self):
+ """Dispose of the connection pool used by this
+ :class:`_asyncio.AsyncEngine`.
+
+ This will close all connection pool connections that are
+ **currently checked in**. See the documentation for the underlying
+ :meth:`_future.Engine.dispose` method for further notes.
+
+ .. seealso::
+
+ :meth:`_future.Engine.dispose`
+
+ """
+
+ await greenlet_spawn(self.sync_engine.dispose)
+
+
+class AsyncTransaction(ProxyComparable, StartableContext):
+    """An asyncio proxy for a :class:`_engine.Transaction`.
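+
+    E.g., a sketch of explicit (non-context-manager) use, assuming an
+    :class:`_asyncio.AsyncConnection` named ``conn`` and a ``text``
+    construct::
+
+        trans = await conn.begin()
+        await conn.execute(text("insert into some_table values (1)"))
+        await trans.commit()
+
+    """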
+
+ __slots__ = ("connection", "sync_transaction", "nested")
+
+ def __init__(self, connection, nested=False):
+ self.connection = connection # AsyncConnection
+ self.sync_transaction = None # sqlalchemy.engine.Transaction
+ self.nested = nested
+
+ @classmethod
+ def _regenerate_proxy_for_target(cls, target):
+ sync_connection = target.connection
+ sync_transaction = target
+ nested = isinstance(target, NestedTransaction)
+
+ async_connection = AsyncConnection._retrieve_proxy_for_target(
+ sync_connection
+ )
+ assert async_connection is not None
+
+ obj = cls.__new__(cls)
+ obj.connection = async_connection
+ obj.sync_transaction = obj._assign_proxied(sync_transaction)
+ obj.nested = nested
+ return obj
+
+ def _sync_transaction(self):
+ if not self.sync_transaction:
+ self._raise_for_not_started()
+ return self.sync_transaction
+
+ @property
+ def _proxied(self):
+ return self.sync_transaction
+
+ @property
+ def is_valid(self):
+ return self._sync_transaction().is_valid
+
+ @property
+ def is_active(self):
+ return self._sync_transaction().is_active
+
+ async def close(self):
+ """Close this :class:`.Transaction`.
+
+        If this transaction is the base transaction in a begin/commit
+        nesting, the transaction will be rolled back.  Otherwise, the
+        method returns without action.
+
+ This is used to cancel a Transaction without affecting the scope of
+ an enclosing transaction.
+
+ """
+ await greenlet_spawn(self._sync_transaction().close)
+
+ async def rollback(self):
+ """Roll back this :class:`.Transaction`."""
+ await greenlet_spawn(self._sync_transaction().rollback)
+
+ async def commit(self):
+ """Commit this :class:`.Transaction`."""
+
+ await greenlet_spawn(self._sync_transaction().commit)
+
+ async def start(self, is_ctxmanager=False):
+ """Start this :class:`_asyncio.AsyncTransaction` object's context
+ outside of using a Python ``with:`` block.
+
+ """
+
+ self.sync_transaction = self._assign_proxied(
+ await greenlet_spawn(
+ self.connection._sync_connection().begin_nested
+ if self.nested
+ else self.connection._sync_connection().begin
+ )
+ )
+ if is_ctxmanager:
+ self.sync_transaction.__enter__()
+ return self
+
+ async def __aexit__(self, type_, value, traceback):
+ await greenlet_spawn(
+ self._sync_transaction().__exit__, type_, value, traceback
+ )
+
+
+def _get_sync_engine_or_connection(async_engine):
+ if isinstance(async_engine, AsyncConnection):
+ return async_engine.sync_connection
+
+ try:
+ return async_engine.sync_engine
+ except AttributeError as e:
+ raise exc.ArgumentError(
+ "AsyncEngine expected, got %r" % async_engine
+ ) from e
+
+
+@inspection._inspects(AsyncConnection)
+def _no_insp_for_async_conn_yet(subject):
+ raise exc.NoInspectionAvailable(
+ "Inspection on an AsyncConnection is currently not supported. "
+ "Please use ``run_sync`` to pass a callable where it's possible "
+ "to call ``inspect`` on the passed connection.",
+ code="xd3s",
+ )
+
+
+@inspection._inspects(AsyncEngine)
+def _no_insp_for_async_engine_yet(subject):
+ raise exc.NoInspectionAvailable(
+ "Inspection on an AsyncEngine is currently not supported. "
+ "Please obtain a connection then use ``conn.run_sync`` to pass a "
+ "callable where it's possible to call ``inspect`` on the "
+ "passed connection.",
+ code="xd3s",
+ )
diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py
new file mode 100644
index 0000000..c5d5e01
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/events.py
@@ -0,0 +1,44 @@
+# ext/asyncio/events.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .engine import AsyncConnectable
+from .session import AsyncSession
+from ...engine import events as engine_event
+from ...orm import events as orm_event
+
+
+class AsyncConnectionEvents(engine_event.ConnectionEvents):
+ _target_class_doc = "SomeEngine"
+ _dispatch_target = AsyncConnectable
+
+ @classmethod
+ def _no_async_engine_events(cls):
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncEngine.sync_engine or "
+ "AsyncConnection.sync_connection attributes."
+ )
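+
+    # Synchronous listeners can instead be applied to the proxied sync
+    # objects; a sketch, assuming an AsyncEngine instance ``async_engine``:
+    #
+    #     from sqlalchemy import event
+    #
+    #     @event.listens_for(
+    #         async_engine.sync_engine, "before_cursor_execute"
+    #     )
+    #     def log_statement(
+    #         conn, cursor, statement, parameters, context, executemany
+    #     ):
+    #         print(statement)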
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ cls._no_async_engine_events()
+
+
+class AsyncSessionEvents(orm_event.SessionEvents):
+ _target_class_doc = "SomeSession"
+ _dispatch_target = AsyncSession
+
+ @classmethod
+ def _no_async_engine_events(cls):
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncSession.sync_session."
+ )
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ cls._no_async_engine_events()
diff --git a/lib/sqlalchemy/ext/asyncio/exc.py b/lib/sqlalchemy/ext/asyncio/exc.py
new file mode 100644
index 0000000..cf0d9a8
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/exc.py
@@ -0,0 +1,21 @@
+# ext/asyncio/exc.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from ... import exc
+
+
+class AsyncMethodRequired(exc.InvalidRequestError):
+    """An API can't be used because its result would not be
+    compatible with async."""
+
+
+class AsyncContextNotStarted(exc.InvalidRequestError):
+    """A startable context manager has not been started."""
+
+
+class AsyncContextAlreadyStarted(exc.InvalidRequestError):
+    """A startable context manager is already started."""
diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py
new file mode 100644
index 0000000..a77b6a8
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/result.py
@@ -0,0 +1,671 @@
+# ext/asyncio/result.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import operator
+
+from . import exc as async_exc
+from ...engine.result import _NO_ROW
+from ...engine.result import FilterResult
+from ...engine.result import FrozenResult
+from ...engine.result import MergedResult
+from ...sql.base import _generative
+from ...util.concurrency import greenlet_spawn
+
+
+class AsyncCommon(FilterResult):
+ async def close(self):
+ """Close this result."""
+
+ await greenlet_spawn(self._real_result.close)
+
+
+class AsyncResult(AsyncCommon):
+ """An asyncio wrapper around a :class:`_result.Result` object.
+
+ The :class:`_asyncio.AsyncResult` only applies to statement executions that
+ use a server-side cursor. It is returned only from the
+ :meth:`_asyncio.AsyncConnection.stream` and
+ :meth:`_asyncio.AsyncSession.stream` methods.
+
+    .. note:: As is the case with :class:`_engine.Result`, this object is
+       used for ORM results returned by :meth:`_asyncio.AsyncSession.stream`,
+ which can yield instances of ORM mapped objects either individually or
+ within tuple-like rows. Note that these result objects do not
+ deduplicate instances or rows automatically as is the case with the
+ legacy :class:`_orm.Query` object. For in-Python de-duplication of
+ instances or rows, use the :meth:`_asyncio.AsyncResult.unique` modifier
+ method.
+
+ .. versionadded:: 1.4
+
+ """
+
+ def __init__(self, real_result):
+ self._real_result = real_result
+
+ self._metadata = real_result._metadata
+ self._unique_filter_state = real_result._unique_filter_state
+
+ # BaseCursorResult pre-generates the "_row_getter". Use that
+ # if available rather than building a second one
+ if "_row_getter" in real_result.__dict__:
+ self._set_memoized_attribute(
+ "_row_getter", real_result.__dict__["_row_getter"]
+ )
+
+ def keys(self):
+ """Return the :meth:`_engine.Result.keys` collection from the
+ underlying :class:`_engine.Result`.
+
+ """
+ return self._metadata.keys
+
+ @_generative
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_asyncio.AsyncResult`.
+
+ Refer to :meth:`_engine.Result.unique` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+
+ """
+ self._unique_filter_state = (set(), strategy)
+
+ def columns(self, *col_expressions):
+ r"""Establish the columns that should be returned in each row.
+
+ Refer to :meth:`_engine.Result.columns` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+
+ """
+ return self._column_slices(col_expressions)
+
+ async def partitions(self, size=None):
+ """Iterate through sub-lists of rows of the size given.
+
+ An async iterator is returned::
+
+ async def scroll_results(connection):
+ result = await connection.stream(select(users_table))
+
+ async for partition in result.partitions(100):
+ print("list of rows: %s" % partition)
+
+ .. seealso::
+
+ :meth:`_engine.Result.partitions`
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = await greenlet_spawn(getter, self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ async def fetchone(self):
+ """Fetch one row.
+
+ When all rows are exhausted, returns None.
+
+ This method is provided for backwards compatibility with
+ SQLAlchemy 1.x.x.
+
+        To fetch the first row of a result only, use the
+        :meth:`_asyncio.AsyncResult.first` method.  To iterate through all
+        rows, iterate the :class:`_asyncio.AsyncResult` object directly.
+
+ :return: a :class:`.Row` object if no filters are applied, or None
+ if no rows remain.
+
+ """
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ return None
+ else:
+ return row
+
+ async def fetchmany(self, size=None):
+ """Fetch many rows.
+
+ When all rows are exhausted, returns an empty list.
+
+ This method is provided for backwards compatibility with
+ SQLAlchemy 1.x.x.
+
+ To fetch rows in groups, use the
+        :meth:`_asyncio.AsyncResult.partitions` method.
+
+ :return: a list of :class:`.Row` objects.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.partitions`
+
+ """
+
+ return await greenlet_spawn(self._manyrow_getter, self, size)
+
+ async def all(self):
+ """Return all rows in a list.
+
+ Closes the result set after invocation. Subsequent invocations
+ will return an empty list.
+
+ :return: a list of :class:`.Row` objects.
+
+ """
+
+ return await greenlet_spawn(self._allrows)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ raise StopAsyncIteration()
+ else:
+ return row
+
+ async def first(self):
+ """Fetch the first row or None if no row is present.
+
+ Closes the result set and discards remaining rows.
+
+ .. note:: This method returns one **row**, e.g. tuple, by default. To
+ return exactly one single scalar value, that is, the first column of
+ the first row, use the :meth:`_asyncio.AsyncResult.scalar` method,
+ or combine :meth:`_asyncio.AsyncResult.scalars` and
+ :meth:`_asyncio.AsyncResult.first`.
+
+ :return: a :class:`.Row` object, or None
+ if no rows remain.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.scalar`
+
+ :meth:`_asyncio.AsyncResult.one`
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, False)
+
+ async def one_or_none(self):
+ """Return at most one result or raise an exception.
+
+ Returns ``None`` if the result has no rows.
+ Raises :class:`.MultipleResultsFound`
+ if multiple rows are returned.
+
+ .. versionadded:: 1.4
+
+ :return: The first :class:`.Row` or None if no row is available.
+
+ :raises: :class:`.MultipleResultsFound`
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.first`
+
+ :meth:`_asyncio.AsyncResult.one`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, False)
+
+ async def scalar_one(self):
+ """Return exactly one scalar result or raise an exception.
+
+ This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
+ then :meth:`_asyncio.AsyncResult.one`.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.one`
+
+ :meth:`_asyncio.AsyncResult.scalars`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, True)
+
+ async def scalar_one_or_none(self):
+ """Return exactly one or no scalar result.
+
+ This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and
+ then :meth:`_asyncio.AsyncResult.one_or_none`.
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.one_or_none`
+
+ :meth:`_asyncio.AsyncResult.scalars`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, True)
+
+ async def one(self):
+ """Return exactly one row or raise an exception.
+
+ Raises :class:`.NoResultFound` if the result returns no
+ rows, or :class:`.MultipleResultsFound` if multiple rows
+ would be returned.
+
+ .. note:: This method returns one **row**, e.g. tuple, by default.
+ To return exactly one single scalar value, that is, the first
+ column of the first row, use the
+ :meth:`_asyncio.AsyncResult.scalar_one` method, or combine
+ :meth:`_asyncio.AsyncResult.scalars` and
+ :meth:`_asyncio.AsyncResult.one`.
+
+ .. versionadded:: 1.4
+
+ :return: The first :class:`.Row`.
+
+ :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound`
+
+ .. seealso::
+
+ :meth:`_asyncio.AsyncResult.first`
+
+ :meth:`_asyncio.AsyncResult.one_or_none`
+
+ :meth:`_asyncio.AsyncResult.scalar_one`
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, False)
+
+ async def scalar(self):
+ """Fetch the first column of the first row, and close the result set.
+
+ Returns None if there are no rows to fetch.
+
+ No validation is performed to test if additional rows remain.
+
+ After calling this method, the object is fully closed,
+ e.g. the :meth:`_engine.CursorResult.close`
+ method will have been called.
+
+        :return: a Python scalar value, or None if no rows remain.
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, True)
+
+ async def freeze(self):
+ """Return a callable object that will produce copies of this
+ :class:`_asyncio.AsyncResult` when invoked.
+
+ The callable object returned is an instance of
+ :class:`_engine.FrozenResult`.
+
+ This is used for result set caching. The method must be called
+ on the result when it has been unconsumed, and calling the method
+ will consume the result fully. When the :class:`_engine.FrozenResult`
+ is retrieved from a cache, it can be called any number of times where
+ it will produce a new :class:`_engine.Result` object each time
+ against its stored set of rows.
+
+ .. seealso::
+
+ :ref:`do_orm_execute_re_executing` - example usage within the
+ ORM to implement a result-set cache.
+
+ """
+
+ return await greenlet_spawn(FrozenResult, self)
+
+ def merge(self, *others):
+ """Merge this :class:`_asyncio.AsyncResult` with other compatible
+ result objects.
+
+ The object returned is an instance of :class:`_engine.MergedResult`,
+ which will be composed of iterators from the given result
+ objects.
+
+ The new result will use the metadata from this result object.
+ The subsequent result objects must be against an identical
+ set of result / cursor metadata, otherwise the behavior is
+ undefined.
+
+ """
+ return MergedResult(self._metadata, (self,) + others)
+
+ def scalars(self, index=0):
+ """Return an :class:`_asyncio.AsyncScalarResult` filtering object which
+ will return single elements rather than :class:`_row.Row` objects.
+
+ Refer to :meth:`_result.Result.scalars` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+ :param index: integer or row key indicating the column to be fetched
+ from each row, defaults to ``0`` indicating the first column.
+
+ :return: a new :class:`_asyncio.AsyncScalarResult` filtering object
+ referring to this :class:`_asyncio.AsyncResult` object.
+
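+        E.g., a sketch (assumes a streamed result and a ``users_table``
+        construct)::
+
+            result = await conn.stream(select(users_table.c.name))
+            names = await result.scalars().all()
+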
+ """
+ return AsyncScalarResult(self._real_result, index)
+
+ def mappings(self):
+ """Apply a mappings filter to returned rows, returning an instance of
+ :class:`_asyncio.AsyncMappingResult`.
+
+ When this filter is applied, fetching rows will return
+ :class:`.RowMapping` objects instead of :class:`.Row` objects.
+
+ Refer to :meth:`_result.Result.mappings` in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+ :return: a new :class:`_asyncio.AsyncMappingResult` filtering object
+ referring to the underlying :class:`_result.Result` object.
+
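+        E.g., a sketch (assumes a ``users_table`` construct with a
+        ``name`` column)::
+
+            result = await conn.stream(select(users_table))
+            async for mapping in result.mappings():
+                print(mapping["name"])
+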
+ """
+
+ return AsyncMappingResult(self._real_result)
+
+
+class AsyncScalarResult(AsyncCommon):
+ """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
+ rather than :class:`_row.Row` values.
+
+ The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the
+ :meth:`_asyncio.AsyncResult.scalars` method.
+
+ Refer to the :class:`_result.ScalarResult` object in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _generate_rows = False
+
+ def __init__(self, real_result, index):
+ self._real_result = real_result
+
+ if real_result._source_supports_scalars:
+ self._metadata = real_result._metadata
+ self._post_creational_filter = None
+ else:
+ self._metadata = real_result._metadata._reduce([index])
+ self._post_creational_filter = operator.itemgetter(0)
+
+ self._unique_filter_state = real_result._unique_filter_state
+
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_asyncio.AsyncScalarResult`.
+
+ See :meth:`_asyncio.AsyncResult.unique` for usage details.
+
+ """
+ self._unique_filter_state = (set(), strategy)
+ return self
+
+ async def partitions(self, size=None):
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = await greenlet_spawn(getter, self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ async def fetchall(self):
+ """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method."""
+
+ return await greenlet_spawn(self._allrows)
+
+ async def fetchmany(self, size=None):
+ """Fetch many objects.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._manyrow_getter, self, size)
+
+ async def all(self):
+ """Return all scalar values in a list.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.all` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._allrows)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ raise StopAsyncIteration()
+ else:
+ return row
+
+ async def first(self):
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.first` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, False)
+
+ async def one_or_none(self):
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, False)
+
+ async def one(self):
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one` except that
+ scalar values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, False)
+
+
+class AsyncMappingResult(AsyncCommon):
+ """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary
+ values rather than :class:`_engine.Row` values.
+
+ The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the
+ :meth:`_asyncio.AsyncResult.mappings` method.
+
+ Refer to the :class:`_result.MappingResult` object in the synchronous
+ SQLAlchemy API for a complete behavioral description.
+
+ .. versionadded:: 1.4
+
+ """
+
+ _generate_rows = True
+
+ _post_creational_filter = operator.attrgetter("_mapping")
+
+ def __init__(self, result):
+ self._real_result = result
+ self._unique_filter_state = result._unique_filter_state
+ self._metadata = result._metadata
+ if result._source_supports_scalars:
+ self._metadata = self._metadata._reduce([0])
+
+ def keys(self):
+ """Return an iterable view which yields the string keys that would
+ be represented by each :class:`.Row`.
+
+ The view also can be tested for key containment using the Python
+ ``in`` operator, which will test both for the string keys represented
+ in the view, as well as for alternate keys such as column objects.
+
+ .. versionchanged:: 1.4 a key view object is returned rather than a
+ plain list.
+
+
+ """
+ return self._metadata.keys
+
+ def unique(self, strategy=None):
+ """Apply unique filtering to the objects returned by this
+ :class:`_asyncio.AsyncMappingResult`.
+
+ See :meth:`_asyncio.AsyncResult.unique` for usage details.
+
+ """
+ self._unique_filter_state = (set(), strategy)
+ return self
+
+ def columns(self, *col_expressions):
+ r"""Establish the columns that should be returned in each row."""
+ return self._column_slices(col_expressions)
+
+ async def partitions(self, size=None):
+ """Iterate through sub-lists of elements of the size given.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ getter = self._manyrow_getter
+
+ while True:
+ partition = await greenlet_spawn(getter, self, size)
+ if partition:
+ yield partition
+ else:
+ break
+
+ async def fetchall(self):
+ """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
+
+ return await greenlet_spawn(self._allrows)
+
+ async def fetchone(self):
+ """Fetch one object.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ return None
+ else:
+ return row
+
+ async def fetchmany(self, size=None):
+ """Fetch many objects.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ return await greenlet_spawn(self._manyrow_getter, self, size)
+
+ async def all(self):
+        """Return all mapping values in a list.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.all` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+
+ return await greenlet_spawn(self._allrows)
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ row = await greenlet_spawn(self._onerow_getter, self)
+ if row is _NO_ROW:
+ raise StopAsyncIteration()
+ else:
+ return row
+
+ async def first(self):
+ """Fetch the first object or None if no object is present.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.first` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+
+ """
+ return await greenlet_spawn(self._only_one_row, False, False, False)
+
+ async def one_or_none(self):
+ """Return at most one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, False, False)
+
+ async def one(self):
+ """Return exactly one object or raise an exception.
+
+ Equivalent to :meth:`_asyncio.AsyncResult.one` except that
+ mapping values, rather than :class:`_result.Row` objects,
+ are returned.
+
+ """
+ return await greenlet_spawn(self._only_one_row, True, True, False)
+
+
+async def _ensure_sync_result(result, calling_method):
+ if not result._is_cursor:
+ cursor_result = getattr(result, "raw", None)
+ else:
+ cursor_result = result
+ if cursor_result and cursor_result.context._is_server_side:
+ await greenlet_spawn(cursor_result.close)
+ raise async_exc.AsyncMethodRequired(
+ "Can't use the %s.%s() method with a "
+ "server-side cursor. "
+ "Use the %s.stream() method for an async "
+ "streaming result set."
+ % (
+ calling_method.__self__.__class__.__name__,
+ calling_method.__name__,
+ calling_method.__self__.__class__.__name__,
+ )
+ )
+ return result
diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py
new file mode 100644
index 0000000..8eca8c5
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/scoping.py
@@ -0,0 +1,107 @@
+# ext/asyncio/scoping.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .session import AsyncSession
+from ...orm.scoping import ScopedSessionMixin
+from ...util import create_proxy_methods
+from ...util import ScopedRegistry
+
+
+@create_proxy_methods(
+ AsyncSession,
+ ":class:`_asyncio.AsyncSession`",
+ ":class:`_asyncio.scoping.async_scoped_session`",
+ classmethods=["close_all", "object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "begin",
+ "begin_nested",
+ "close",
+ "commit",
+ "connection",
+ "delete",
+ "execute",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "flush",
+ "get",
+ "get_bind",
+ "is_modified",
+ "invalidate",
+ "merge",
+ "refresh",
+ "rollback",
+ "scalar",
+ "scalars",
+ "stream",
+ "stream_scalars",
+ ],
+ attributes=[
+ "bind",
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ ],
+)
+class async_scoped_session(ScopedSessionMixin):
+ """Provides scoped management of :class:`.AsyncSession` objects.
+
+ See the section :ref:`asyncio_scoped_session` for usage details.
+
+ .. versionadded:: 1.4.19
+
+
+ """
+
+ _support_async = True
+
+ def __init__(self, session_factory, scopefunc):
+ """Construct a new :class:`_asyncio.async_scoped_session`.
+
+ :param session_factory: a factory to create new :class:`_asyncio.AsyncSession`
+ instances. This is usually, but not necessarily, an instance
+ of :class:`_orm.sessionmaker` which itself was passed the
+ :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_`
+ parameter::
+
+            async_session_factory = sessionmaker(
+                some_async_engine, class_=AsyncSession
+            )
+            AsyncScopedSession = async_scoped_session(
+                async_session_factory, scopefunc=current_task
+            )
+
+ :param scopefunc: function which defines
+ the current scope. A function such as ``asyncio.current_task``
+ may be useful here.
+
+ """ # noqa: E501
+
+ self.session_factory = session_factory
+ self.registry = ScopedRegistry(session_factory, scopefunc)
+
+ @property
+ def _proxied(self):
+ return self.registry()
+
+ async def remove(self):
+ """Dispose of the current :class:`.AsyncSession`, if present.
+
+        Unlike the :meth:`_orm.scoped_session.remove` method, this method
+        is a coroutine and awaits the :meth:`.AsyncSession.close` method
+        of the current session.
+
+ """
+
+ if self.registry.has():
+ await self.registry().close()
+ self.registry.clear()
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
new file mode 100644
index 0000000..378cbcb
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -0,0 +1,759 @@
+# ext/asyncio/session.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import asyncio
+
+from . import engine
+from . import result as _result
+from .base import ReversibleProxy
+from .base import StartableContext
+from .result import _ensure_sync_result
+from ... import util
+from ...orm import object_session
+from ...orm import Session
+from ...orm import state as _instance_state
+from ...util.concurrency import greenlet_spawn
+
+_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True})
+_STREAM_OPTIONS = util.immutabledict({"stream_results": True})
+
+
+@util.create_proxy_methods(
+ Session,
+ ":class:`_orm.Session`",
+ ":class:`_asyncio.AsyncSession`",
+ classmethods=["object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "is_modified",
+ "in_transaction",
+ "in_nested_transaction",
+ ],
+ attributes=[
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ ],
+)
+class AsyncSession(ReversibleProxy):
+ """Asyncio version of :class:`_orm.Session`.
+
+ The :class:`_asyncio.AsyncSession` is a proxy for a traditional
+ :class:`_orm.Session` instance.
+
+ .. versionadded:: 1.4
+
+ To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session`
+ implementations, see the
+ :paramref:`_asyncio.AsyncSession.sync_session_class` parameter.
+
+
+ """
+
+ _is_asyncio = True
+
+ dispatch = None
+
+ def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
+ r"""Construct a new :class:`_asyncio.AsyncSession`.
+
+ All parameters other than ``sync_session_class`` are passed to the
+ ``sync_session_class`` callable directly to instantiate a new
+ :class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for
+ parameter documentation.
+
+ :param sync_session_class:
+ A :class:`_orm.Session` subclass or other callable which will be used
+ to construct the :class:`_orm.Session` which will be proxied. This
+ parameter may be used to provide custom :class:`_orm.Session`
+ subclasses. Defaults to the
+ :attr:`_asyncio.AsyncSession.sync_session_class` class-level
+ attribute.
+
+ .. versionadded:: 1.4.24
+
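+        E.g., a sketch of typical construction; the database URL shown is
+        illustrative only::
+
+            from sqlalchemy.ext.asyncio import AsyncSession
+            from sqlalchemy.ext.asyncio import create_async_engine
+
+            engine = create_async_engine("postgresql+asyncpg://u:p@host/db")
+
+            async with AsyncSession(engine) as session:
+                ...
+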
+ """
+ kw["future"] = True
+ if bind:
+ self.bind = bind
+ bind = engine._get_sync_engine_or_connection(bind)
+
+ if binds:
+ self.binds = binds
+ binds = {
+ key: engine._get_sync_engine_or_connection(b)
+ for key, b in binds.items()
+ }
+
+ if sync_session_class:
+ self.sync_session_class = sync_session_class
+
+ self.sync_session = self._proxied = self._assign_proxied(
+ self.sync_session_class(bind=bind, binds=binds, **kw)
+ )
+
+ sync_session_class = Session
+ """The class or callable that provides the
+ underlying :class:`_orm.Session` instance for a particular
+ :class:`_asyncio.AsyncSession`.
+
+ At the class level, this attribute is the default value for the
+ :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom
+ subclasses of :class:`_asyncio.AsyncSession` can override this.
+
+ At the instance level, this attribute indicates the current class or
+ callable that was used to provide the :class:`_orm.Session` instance for
+ this :class:`_asyncio.AsyncSession` instance.
+
+ .. versionadded:: 1.4.24
+
+ """
+
+ sync_session: Session
+ """Reference to the underlying :class:`_orm.Session` this
+ :class:`_asyncio.AsyncSession` proxies requests towards.
+
+ This instance can be used as an event target.
+
+ .. seealso::
+
+ :ref:`asyncio_events`
+
+ """
+
+ async def refresh(
+ self, instance, attribute_names=None, with_for_update=None
+ ):
+ """Expire and refresh the attributes on the given instance.
+
+ A query will be issued to the database and all attributes will be
+ refreshed with their current database value.
+
+ This is the async version of the :meth:`_orm.Session.refresh` method.
+ See that method for a complete description of all options.
+
+ .. seealso::
+
+ :meth:`_orm.Session.refresh` - main documentation for refresh
+
+ """
+
+ return await greenlet_spawn(
+ self.sync_session.refresh,
+ instance,
+ attribute_names=attribute_names,
+ with_for_update=with_for_update,
+ )
+
+ async def run_sync(self, fn, *arg, **kw):
+ """Invoke the given sync callable passing sync self as the first
+ argument.
+
+ This method maintains the asyncio event loop all the way through
+ to the database connection by running the given callable in a
+ specially instrumented greenlet.
+
+ E.g.::
+
+ with AsyncSession(async_engine) as session:
+ await session.run_sync(some_business_method)
+
+ .. note::
+
+ The provided callable is invoked inline within the asyncio event
+ loop, and will block on traditional IO calls. IO within this
+ callable should only call into SQLAlchemy's asyncio database
+ APIs which will be properly adapted to the greenlet context.
+
+ .. seealso::
+
+ :ref:`session_run_sync`
+ """
+
+ return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
+
+ async def execute(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a buffered
+ :class:`_engine.Result` object.
+
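+        E.g., a sketch (``User`` is an assumed mapped class)::
+
+            result = await session.execute(
+                select(User).where(User.name == "sandy")
+            )
+            user = result.scalars().one()
+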
+ .. seealso::
+
+ :meth:`_orm.Session.execute` - main documentation for execute
+
+ """
+
+ if execution_options:
+ execution_options = util.immutabledict(execution_options).union(
+ _EXECUTE_OPTIONS
+ )
+ else:
+ execution_options = _EXECUTE_OPTIONS
+
+ result = await greenlet_spawn(
+ self.sync_session.execute,
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return await _ensure_sync_result(result, self.execute)
+
+ async def scalar(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a scalar result.
+
+ .. seealso::
+
+ :meth:`_orm.Session.scalar` - main documentation for scalar
+
+ """
+
+ result = await self.execute(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return result.scalar()
+
+ async def scalars(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return scalar results.
+
+ :return: a :class:`_result.ScalarResult` object
+
+ .. versionadded:: 1.4.24
+
+ .. seealso::
+
+ :meth:`_orm.Session.scalars` - main documentation for scalars
+
+ :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version
+
+ """
+
+ result = await self.execute(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return result.scalars()
+
+ async def get(
+ self,
+ entity,
+ ident,
+ options=None,
+ populate_existing=False,
+ with_for_update=None,
+ identity_token=None,
+ ):
+ """Return an instance based on the given primary key identifier,
+ or ``None`` if not found.
+
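+        E.g., a sketch (``User`` is an assumed mapped class)::
+
+            user = await session.get(User, 5)
+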
+ .. seealso::
+
+ :meth:`_orm.Session.get` - main documentation for get
+
+
+ """
+ return await greenlet_spawn(
+ self.sync_session.get,
+ entity,
+ ident,
+ options=options,
+ populate_existing=populate_existing,
+ with_for_update=with_for_update,
+ identity_token=identity_token,
+ )
+
+ async def stream(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a streaming
+ :class:`_asyncio.AsyncResult` object.
+
+ """
+
+ if execution_options:
+ execution_options = util.immutabledict(execution_options).union(
+ _STREAM_OPTIONS
+ )
+ else:
+ execution_options = _STREAM_OPTIONS
+
+ result = await greenlet_spawn(
+ self.sync_session.execute,
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return _result.AsyncResult(result)
+
+ async def stream_scalars(
+ self,
+ statement,
+ params=None,
+ execution_options=util.EMPTY_DICT,
+ bind_arguments=None,
+ **kw
+ ):
+ """Execute a statement and return a stream of scalar results.
+
+ :return: an :class:`_asyncio.AsyncScalarResult` object
+
+ .. versionadded:: 1.4.24
+
+ .. seealso::
+
+ :meth:`_orm.Session.scalars` - main documentation for scalars
+
+ :meth:`_asyncio.AsyncSession.scalars` - non streaming version
+
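+        E.g., a sketch (``User`` is an assumed mapped class)::
+
+            result = await session.stream_scalars(select(User))
+            async for user in result:
+                print(user.name)
+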
+ """
+
+ result = await self.stream(
+ statement,
+ params=params,
+ execution_options=execution_options,
+ bind_arguments=bind_arguments,
+ **kw
+ )
+ return result.scalars()
+
+ async def delete(self, instance):
+ """Mark an instance as deleted.
+
+ The database delete operation occurs upon ``flush()``.
+
+ As this operation may need to cascade along unloaded relationships,
+ it is awaitable to allow for those queries to take place.
+
+ .. seealso::
+
+ :meth:`_orm.Session.delete` - main documentation for delete
+
+ """
+ return await greenlet_spawn(self.sync_session.delete, instance)
+
+ async def merge(self, instance, load=True, options=None):
+ """Copy the state of a given instance into a corresponding instance
+ within this :class:`_asyncio.AsyncSession`.
+
+ .. seealso::
+
+ :meth:`_orm.Session.merge` - main documentation for merge
+
+ """
+ return await greenlet_spawn(
+ self.sync_session.merge, instance, load=load, options=options
+ )
+
+ async def flush(self, objects=None):
+ """Flush all the object changes to the database.
+
+ .. seealso::
+
+ :meth:`_orm.Session.flush` - main documentation for flush
+
+ """
+ await greenlet_spawn(self.sync_session.flush, objects=objects)
+
+ def get_transaction(self):
+ """Return the current root transaction in progress, if any.
+
+ :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
+ ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+ trans = self.sync_session.get_transaction()
+ if trans is not None:
+ return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ def get_nested_transaction(self):
+ """Return the current nested transaction in progress, if any.
+
+ :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
+ ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+
+ trans = self.sync_session.get_nested_transaction()
+ if trans is not None:
+ return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+ else:
+ return None
+
+ def get_bind(self, mapper=None, clause=None, bind=None, **kw):
+ """Return a "bind" to which the synchronous proxied :class:`_orm.Session`
+ is bound.
+
+ Unlike the :meth:`_orm.Session.get_bind` method, this method is
+ currently **not** used by this :class:`.AsyncSession` in any way
+ in order to resolve engines for requests.
+
+ .. note::
+
+ This method proxies directly to the :meth:`_orm.Session.get_bind`
+ method, however is currently **not** useful as an override target,
+ in contrast to that of the :meth:`_orm.Session.get_bind` method.
+ The example below illustrates how to implement custom
+ :meth:`_orm.Session.get_bind` schemes that work with
+ :class:`.AsyncSession` and :class:`.AsyncEngine`.
+
+ The pattern introduced at :ref:`session_custom_partitioning`
+ illustrates how to apply a custom bind-lookup scheme to a
+ :class:`_orm.Session` given a set of :class:`_engine.Engine` objects.
+ To apply a corresponding :meth:`_orm.Session.get_bind` implementation
+ for use with a :class:`.AsyncSession` and :class:`.AsyncEngine`
+ objects, continue to subclass :class:`_orm.Session` and apply it to
+ :class:`.AsyncSession` using
+ :paramref:`.AsyncSession.sync_session_class`. The inner method must
+ continue to return :class:`_engine.Engine` instances, which can be
+ acquired from a :class:`_asyncio.AsyncEngine` using the
+ :attr:`_asyncio.AsyncEngine.sync_engine` attribute::
+
+ # using example from "Custom Vertical Partitioning"
+
+
+ import random
+
+ from sqlalchemy.ext.asyncio import AsyncSession
+ from sqlalchemy.ext.asyncio import create_async_engine
+ from sqlalchemy.orm import Session, sessionmaker
+
+ # construct async engines w/ async drivers
+ engines = {
+ 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"),
+ 'other':create_async_engine("sqlite+aiosqlite:///other.db"),
+ 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"),
+ 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"),
+ }
+
+ class RoutingSession(Session):
+ def get_bind(self, mapper=None, clause=None, **kw):
+ # within get_bind(), return sync engines
+ if mapper and issubclass(mapper.class_, MyOtherClass):
+ return engines['other'].sync_engine
+ elif self._flushing or isinstance(clause, (Update, Delete)):
+ return engines['leader'].sync_engine
+ else:
+ return engines[
+ random.choice(['follower1','follower2'])
+ ].sync_engine
+
+ # apply to AsyncSession using sync_session_class
+ AsyncSessionMaker = sessionmaker(
+ class_=AsyncSession,
+ sync_session_class=RoutingSession
+ )
+
+ The :meth:`_orm.Session.get_bind` method is called in a non-asyncio,
+ implicitly non-blocking context in the same manner as ORM event hooks
+ and functions that are invoked via :meth:`.AsyncSession.run_sync`, so
+ routines that wish to run SQL commands inside of
+ :meth:`_orm.Session.get_bind` can continue to do so using
+ blocking-style code, which will be translated to implicitly async calls
+ at the point of invoking IO on the database drivers.
+
+ """ # noqa: E501
+
+ return self.sync_session.get_bind(
+ mapper=mapper, clause=clause, bind=bind, **kw
+ )
+
+ async def connection(self, **kw):
+ r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
+ this :class:`.Session` object's transactional state.
+
+ This method may also be used to establish execution options for the
+ database connection used by the current transaction.
+
+ .. versionadded:: 1.4.24 Added \**kw arguments which are passed
+ through to the underlying :meth:`_orm.Session.connection` method.
+
+ .. seealso::
+
+ :meth:`_orm.Session.connection` - main documentation for
+ "connection"
+
+ """
+
+ sync_connection = await greenlet_spawn(
+ self.sync_session.connection, **kw
+ )
+ return engine.AsyncConnection._retrieve_proxy_for_target(
+ sync_connection
+ )
+
+ def begin(self, **kw):
+ """Return an :class:`_asyncio.AsyncSessionTransaction` object.
+
+ The underlying :class:`_orm.Session` will perform the
+ "begin" action when the :class:`_asyncio.AsyncSessionTransaction`
+ object is entered::
+
+ async with async_session.begin():
+ # .. ORM transaction is begun
+
+ Note that database IO will not normally occur when the session-level
+ transaction is begun, as database transactions begin on an
+        on-demand basis.  However, the begin block is async in order to
+        accommodate a :meth:`_orm.SessionEvents.after_transaction_create`
+        event hook that may perform IO.
+
+ For a general description of ORM begin, see
+ :meth:`_orm.Session.begin`.
+
+ """
+
+ return AsyncSessionTransaction(self)
+
+ def begin_nested(self, **kw):
+ """Return an :class:`_asyncio.AsyncSessionTransaction` object
+ which will begin a "nested" transaction, e.g. SAVEPOINT.
+
+ Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`.
+
+ For a general description of ORM begin nested, see
+ :meth:`_orm.Session.begin_nested`.
+
+ """
+
+ return AsyncSessionTransaction(self, nested=True)
+
+ async def rollback(self):
+        """Roll back the current transaction in progress."""
+ return await greenlet_spawn(self.sync_session.rollback)
+
+ async def commit(self):
+ """Commit the current transaction in progress."""
+ return await greenlet_spawn(self.sync_session.commit)
+
+ async def close(self):
+ """Close out the transactional resources and ORM objects used by this
+ :class:`_asyncio.AsyncSession`.
+
+ This expunges all ORM objects associated with this
+ :class:`_asyncio.AsyncSession`, ends any transaction in progress and
+ :term:`releases` any :class:`_asyncio.AsyncConnection` objects which
+ this :class:`_asyncio.AsyncSession` itself has checked out from
+        associated :class:`_asyncio.AsyncEngine` objects.  The operation
+        then leaves the :class:`_asyncio.AsyncSession` in a state in which
+        it may be used again.
+
+ .. tip::
+
+ The :meth:`_asyncio.AsyncSession.close` method **does not prevent
+ the Session from being used again**. The
+ :class:`_asyncio.AsyncSession` itself does not actually have a
+ distinct "closed" state; it merely means the
+ :class:`_asyncio.AsyncSession` will release all database
+ connections and ORM objects.
+
+
+ .. seealso::
+
+ :ref:`session_closing` - detail on the semantics of
+ :meth:`_asyncio.AsyncSession.close`
+
+ """
+ await greenlet_spawn(self.sync_session.close)
+
+ async def invalidate(self):
+ """Close this Session, using connection invalidation.
+
+ For a complete description, see :meth:`_orm.Session.invalidate`.
+ """
+ return await greenlet_spawn(self.sync_session.invalidate)
+
+    @classmethod
+    async def close_all(cls):
+        """Close all :class:`_asyncio.AsyncSession` sessions."""
+        return await greenlet_spawn(Session.close_all)
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, type_, value, traceback):
+ await asyncio.shield(self.close())
+
+ def _maker_context_manager(self):
+ # no @contextlib.asynccontextmanager until python3.7, gr
+ return _AsyncSessionContextManager(self)
+
+
+class _AsyncSessionContextManager:
+ def __init__(self, async_session):
+ self.async_session = async_session
+
+ async def __aenter__(self):
+ self.trans = self.async_session.begin()
+ await self.trans.__aenter__()
+ return self.async_session
+
+ async def __aexit__(self, type_, value, traceback):
+ async def go():
+ await self.trans.__aexit__(type_, value, traceback)
+ await self.async_session.__aexit__(type_, value, traceback)
+
+ await asyncio.shield(go())
+
+
+class AsyncSessionTransaction(ReversibleProxy, StartableContext):
+ """A wrapper for the ORM :class:`_orm.SessionTransaction` object.
+
+ This object is provided so that a transaction-holding object
+ for the :meth:`_asyncio.AsyncSession.begin` may be returned.
+
+ The object supports both explicit calls to
+ :meth:`_asyncio.AsyncSessionTransaction.commit` and
+ :meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an
+ async context manager.
+
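+    E.g., a sketch, where ``async_session`` is an assumed
+    :class:`_asyncio.AsyncSession`::
+
+        # as an async context manager
+        async with async_session.begin():
+            ...  # ORM work; commits on success, rolls back on error
+
+        # or started explicitly by awaiting it
+        trans = await async_session.begin()
+        try:
+            ...  # ORM work
+        except:
+            await trans.rollback()
+            raise
+        else:
+            await trans.commit()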
+
+ .. versionadded:: 1.4
+
+ """
+
+ __slots__ = ("session", "sync_transaction", "nested")
+
+ def __init__(self, session, nested=False):
+ self.session = session
+ self.nested = nested
+ self.sync_transaction = None
+
+ @property
+ def is_active(self):
+ return (
+ self._sync_transaction() is not None
+ and self._sync_transaction().is_active
+ )
+
+ def _sync_transaction(self):
+ if not self.sync_transaction:
+ self._raise_for_not_started()
+ return self.sync_transaction
+
+ async def rollback(self):
+        """Roll back this :class:`_asyncio.AsyncSessionTransaction`."""
+ await greenlet_spawn(self._sync_transaction().rollback)
+
+ async def commit(self):
+        """Commit this :class:`_asyncio.AsyncSessionTransaction`."""
+
+ await greenlet_spawn(self._sync_transaction().commit)
+
+ async def start(self, is_ctxmanager=False):
+ self.sync_transaction = self._assign_proxied(
+ await greenlet_spawn(
+ self.session.sync_session.begin_nested
+ if self.nested
+ else self.session.sync_session.begin
+ )
+ )
+ if is_ctxmanager:
+ self.sync_transaction.__enter__()
+ return self
+
+ async def __aexit__(self, type_, value, traceback):
+ await greenlet_spawn(
+ self._sync_transaction().__exit__, type_, value, traceback
+ )
+
+
+def async_object_session(instance):
+ """Return the :class:`_asyncio.AsyncSession` to which the given instance
+ belongs.
+
+ This function makes use of the sync-API function
+ :class:`_orm.object_session` to retrieve the :class:`_orm.Session` which
+ refers to the given instance, and from there links it to the original
+ :class:`_asyncio.AsyncSession`.
+
+ If the :class:`_asyncio.AsyncSession` has been garbage collected, the
+ return value is ``None``.
+
+ This functionality is also available from the
+ :attr:`_orm.InstanceState.async_session` accessor.
+
+ :param instance: an ORM mapped instance
+ :return: an :class:`_asyncio.AsyncSession` object, or ``None``.
+
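+    E.g., a sketch, where ``some_object`` is an assumed mapped instance
+    loaded by an :class:`_asyncio.AsyncSession`::
+
+        session = async_object_session(some_object)
+        if session is not None:
+            await session.commit()
+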
+ .. versionadded:: 1.4.18
+
+ """
+
+ session = object_session(instance)
+ if session is not None:
+ return async_session(session)
+ else:
+ return None
+
+
+def async_session(session):
+ """Return the :class:`_asyncio.AsyncSession` which is proxying the given
+ :class:`_orm.Session` object, if any.
+
+ :param session: a :class:`_orm.Session` instance.
+ :return: a :class:`_asyncio.AsyncSession` instance, or ``None``.
+
+ .. versionadded:: 1.4.18
+
+ """
+ return AsyncSession._retrieve_proxy_for_target(session, regenerate=False)
+
+
+_instance_state._async_provider = async_session
diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py
new file mode 100644
index 0000000..a5d7267
--- /dev/null
+++ b/lib/sqlalchemy/ext/automap.py
@@ -0,0 +1,1234 @@
+# ext/automap.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Define an extension to the :mod:`sqlalchemy.ext.declarative` system
+which automatically generates mapped classes and relationships from a database
+schema, typically though not necessarily one which is reflected.
+
+It is hoped that the :class:`.AutomapBase` system provides a quick
+and modernized solution to the problem that the very famous
+`SQLSoup <https://sqlsoup.readthedocs.io/en/latest/>`_
+also tries to solve, that of generating a quick and rudimentary object
+model from an existing database on the fly. By addressing the issue strictly
+at the mapper configuration level, and integrating fully with existing
+Declarative class techniques, :class:`.AutomapBase` seeks to provide
+a well-integrated approach to the issue of expediently auto-generating ad-hoc
+mappings.
+
+.. tip:: The :ref:`automap_toplevel` extension is geared towards a
+ "zero declaration" approach, where a complete ORM model including classes
+ and pre-named relationships can be generated on the fly from a database
+ schema. For applications that still want to use explicit class declarations
+ including explicit relationship definitions in conjunction with reflection
+ of tables, the :class:`.DeferredReflection` class, described at
+ :ref:`orm_declarative_reflected_deferred_reflection`, is a better choice.
+
+
+
+Basic Use
+=========
+
+The simplest usage is to reflect an existing database into a new model.
+We create a new :class:`.AutomapBase` class in a manner similar to how
+we create a declarative base class, using :func:`.automap_base`.
+We then call :meth:`.AutomapBase.prepare` on the resulting base class,
+asking it to reflect the schema and produce mappings::
+
+ from sqlalchemy.ext.automap import automap_base
+ from sqlalchemy.orm import Session
+ from sqlalchemy import create_engine
+
+ Base = automap_base()
+
+ # engine, suppose it has two tables 'user' and 'address' set up
+ engine = create_engine("sqlite:///mydatabase.db")
+
+ # reflect the tables
+ Base.prepare(autoload_with=engine)
+
+ # mapped classes are now created with names by default
+ # matching that of the table name.
+ User = Base.classes.user
+ Address = Base.classes.address
+
+ session = Session(engine)
+
+ # rudimentary relationships are produced
+ session.add(Address(email_address="foo@bar.com", user=User(name="foo")))
+ session.commit()
+
+    # collection-based relationships are by default named
+    # "<classname>_collection"
+    u1 = session.query(User).first()
+    print(u1.address_collection)
+
+Above, calling :meth:`.AutomapBase.prepare` while passing along the
+:paramref:`.AutomapBase.prepare.autoload_with` parameter indicates that the
+:meth:`_schema.MetaData.reflect`
+method will be called on this declarative base
+class's :class:`_schema.MetaData` collection; then, each **viable**
+:class:`_schema.Table` within the :class:`_schema.MetaData`
+will get a new mapped class
+generated automatically. The :class:`_schema.ForeignKeyConstraint`
+objects which
+link the various tables together will be used to produce new, bidirectional
+:func:`_orm.relationship` objects between classes.
+The classes and relationships
+follow along a default naming scheme that we can customize. At this point,
+our basic mapping consisting of related ``User`` and ``Address`` classes is
+ready to use in the traditional way.
+
+.. note:: By **viable**, we mean that for a table to be mapped, it must
+ specify a primary key. Additionally, if the table is detected as being
+ a pure association table between two other tables, it will not be directly
+ mapped and will instead be configured as a many-to-many table between
+ the mappings for the two referring tables.
+
+Generating Mappings from an Existing MetaData
+=============================================
+
+We can pass a pre-declared :class:`_schema.MetaData` object to
+:func:`.automap_base`.
+This object can be constructed in any way, including programmatically, from
+a serialized file, or from itself being reflected using
+:meth:`_schema.MetaData.reflect`.
+Below we illustrate a combination of reflection and
+explicit table declaration::
+
+    from sqlalchemy import (
+        create_engine, MetaData, Table, Column, Integer, ForeignKey)
+ from sqlalchemy.ext.automap import automap_base
+ engine = create_engine("sqlite:///mydatabase.db")
+
+ # produce our own MetaData object
+ metadata = MetaData()
+
+ # we can reflect it ourselves from a database, using options
+ # such as 'only' to limit what tables we look at...
+ metadata.reflect(engine, only=['user', 'address'])
+
+ # ... or just define our own Table objects with it (or combine both)
+ Table('user_order', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('user_id', ForeignKey('user.id'))
+ )
+
+ # we can then produce a set of mappings from this MetaData.
+ Base = automap_base(metadata=metadata)
+
+ # calling prepare() just sets up mapped classes and relationships.
+ Base.prepare()
+
+ # mapped classes are ready
+ User, Address, Order = Base.classes.user, Base.classes.address,\
+ Base.classes.user_order
+
+Specifying Classes Explicitly
+=============================
+
+.. tip:: If explicit classes are expected to be prominent in an application,
+ consider using :class:`.DeferredReflection` instead.
+
+The :mod:`.sqlalchemy.ext.automap` extension allows classes to be defined
+explicitly, in a way similar to that of the :class:`.DeferredReflection` class.
+Classes that extend from :class:`.AutomapBase` act like regular declarative
+classes, but are not immediately mapped after their construction, and are
+instead mapped when we call :meth:`.AutomapBase.prepare`. The
+:meth:`.AutomapBase.prepare` method will make use of the classes we've
+established based on the table name we use. If our schema contains tables
+``user`` and ``address``, we can define one or both of the classes to be used::
+
+    from sqlalchemy.ext.automap import automap_base
+    from sqlalchemy.orm import relationship, Session
+    from sqlalchemy import create_engine, Column, String
+
+ # automap base
+ Base = automap_base()
+
+ # pre-declare User for the 'user' table
+ class User(Base):
+ __tablename__ = 'user'
+
+ # override schema elements like Columns
+ user_name = Column('name', String)
+
+ # override relationships too, if desired.
+ # we must use the same name that automap would use for the
+ # relationship, and also must refer to the class name that automap will
+ # generate for "address"
+ address_collection = relationship("address", collection_class=set)
+
+ # reflect
+ engine = create_engine("sqlite:///mydatabase.db")
+ Base.prepare(autoload_with=engine)
+
+ # we still have Address generated from the tablename "address",
+ # but User is the same as Base.classes.User now
+
+ Address = Base.classes.address
+
+    session = Session(engine)
+    u1 = session.query(User).first()
+    print(u1.address_collection)
+
+ # the backref is still there:
+ a1 = session.query(Address).first()
+    print(a1.user)
+
+Above, one of the more intricate details is that we illustrated overriding
+one of the :func:`_orm.relationship` objects that automap would have created.
+To do this, we needed to make sure the names match up with what automap
+would normally generate, in that the relationship name would be
+``User.address_collection`` and the name of the class referred to, from
+automap's perspective, is called ``address``, even though we are referring to
+it as ``Address`` within our usage of this class.
+
+Overriding Naming Schemes
+=========================
+
+:mod:`.sqlalchemy.ext.automap` is tasked with producing mapped classes and
+relationship names based on a schema, which means it has decision points in how
+these names are determined. These three decision points are provided using
+functions which can be passed to the :meth:`.AutomapBase.prepare` method, and
+are known as :func:`.classname_for_table`,
+:func:`.name_for_scalar_relationship`,
+and :func:`.name_for_collection_relationship`. Any or all of these
+functions are provided as in the example below, where we use a "camel case"
+scheme for class names and a "pluralizer" for collection names using the
+`Inflect <https://pypi.org/project/inflect>`_ package::
+
+ import re
+ import inflect
+
+ def camelize_classname(base, tablename, table):
+        """Produce a 'camelized' class name, e.g.
+        'words_and_underscores' -> 'WordsAndUnderscores'"""
+
+ return str(tablename[0].upper() + \
+ re.sub(r'_([a-z])', lambda m: m.group(1).upper(), tablename[1:]))
+
+ _pluralizer = inflect.engine()
+ def pluralize_collection(base, local_cls, referred_cls, constraint):
+        """Produce an 'uncamelized', 'pluralized' class name, e.g.
+        'SomeTerm' -> 'some_terms'"""
+
+ referred_name = referred_cls.__name__
+ uncamelized = re.sub(r'[A-Z]',
+ lambda m: "_%s" % m.group(0).lower(),
+ referred_name)[1:]
+ pluralized = _pluralizer.plural(uncamelized)
+ return pluralized
+
+    from sqlalchemy.ext.automap import automap_base
+    from sqlalchemy import create_engine
+
+ Base = automap_base()
+
+ engine = create_engine("sqlite:///mydatabase.db")
+
+ Base.prepare(autoload_with=engine,
+ classname_for_table=camelize_classname,
+ name_for_collection_relationship=pluralize_collection
+ )
+
+From the above mapping, we would now have classes ``User`` and ``Address``,
+where the collection from ``User`` to ``Address`` is called
+``User.addresses``::
+
+ User, Address = Base.classes.User, Base.classes.Address
+
+ u1 = User(addresses=[Address(email="foo@bar.com")])
+
+Relationship Detection
+======================
+
+The vast majority of what automap accomplishes is the generation of
+:func:`_orm.relationship` structures based on foreign keys. The mechanism
+by which this works for many-to-one and one-to-many relationships is as
+follows:
+
+1. A given :class:`_schema.Table`, known to be mapped to a particular class,
+ is examined for :class:`_schema.ForeignKeyConstraint` objects.
+
+2. From each :class:`_schema.ForeignKeyConstraint`, the remote
+ :class:`_schema.Table`
+ object present is matched up to the class to which it is to be mapped,
+ if any, else it is skipped.
+
+3. As the :class:`_schema.ForeignKeyConstraint`
+ we are examining corresponds to a
+ reference from the immediate mapped class, the relationship will be set up
+ as a many-to-one referring to the referred class; a corresponding
+ one-to-many backref will be created on the referred class referring
+ to this class.
+
+4. If any of the columns that are part of the
+ :class:`_schema.ForeignKeyConstraint`
+ are not nullable (e.g. ``nullable=False``), a
+ :paramref:`_orm.relationship.cascade` keyword argument
+ of ``all, delete-orphan`` will be added to the keyword arguments to
+ be passed to the relationship or backref. If the
+ :class:`_schema.ForeignKeyConstraint` reports that
+ :paramref:`_schema.ForeignKeyConstraint.ondelete`
+ is set to ``CASCADE`` for a not null or ``SET NULL`` for a nullable
+ set of columns, the option :paramref:`_orm.relationship.passive_deletes`
+ flag is set to ``True`` in the set of relationship keyword arguments.
+ Note that not all backends support reflection of ON DELETE.
+
+ .. versionadded:: 1.0.0 - automap will detect non-nullable foreign key
+ constraints when producing a one-to-many relationship and establish
+ a default cascade of ``all, delete-orphan`` if so; additionally,
+ if the constraint specifies
+ :paramref:`_schema.ForeignKeyConstraint.ondelete`
+ of ``CASCADE`` for non-nullable or ``SET NULL`` for nullable columns,
+ the ``passive_deletes=True`` option is also added.
+
+5. The names of the relationships are determined using the
+ :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` and
+ :paramref:`.AutomapBase.prepare.name_for_collection_relationship`
+ callable functions. It is important to note that the default relationship
+   naming derives the name from the **actual class name**. If you've
+ given a particular class an explicit name by declaring it, or specified an
+ alternate class naming scheme, that's the name from which the relationship
+ name will be derived.
+
+6. The classes are inspected for an existing mapped property matching these
+ names. If one is detected on one side, but none on the other side,
+ :class:`.AutomapBase` attempts to create a relationship on the missing side,
+ then uses the :paramref:`_orm.relationship.back_populates`
+ parameter in order to
+ point the new relationship to the other side.
+
+7. In the usual case where no relationship is on either side,
+ :meth:`.AutomapBase.prepare` produces a :func:`_orm.relationship` on the
+ "many-to-one" side and matches it to the other using the
+ :paramref:`_orm.relationship.backref` parameter.
+
+8. Production of the :func:`_orm.relationship` and optionally the
+ :func:`.backref`
+ is handed off to the :paramref:`.AutomapBase.prepare.generate_relationship`
+ function, which can be supplied by the end-user in order to augment
+ the arguments passed to :func:`_orm.relationship` or :func:`.backref` or to
+ make use of custom implementations of these functions.
+
+Custom Relationship Arguments
+-----------------------------
+
+The :paramref:`.AutomapBase.prepare.generate_relationship` hook can be used
+to add parameters to relationships. For most cases, we can make use of the
+existing :func:`.automap.generate_relationship` function to return
+the object, after augmenting the given keyword dictionary with our own
+arguments.
+
+Below is an illustration of how to send
+:paramref:`_orm.relationship.cascade` and
+:paramref:`_orm.relationship.passive_deletes`
+options along to all one-to-many relationships::
+
+    from sqlalchemy.ext.automap import generate_relationship
+    from sqlalchemy.orm import interfaces
+
+ def _gen_relationship(base, direction, return_fn,
+ attrname, local_cls, referred_cls, **kw):
+ if direction is interfaces.ONETOMANY:
+ kw['cascade'] = 'all, delete-orphan'
+ kw['passive_deletes'] = True
+ # make use of the built-in function to actually return
+ # the result.
+ return generate_relationship(base, direction, return_fn,
+ attrname, local_cls, referred_cls, **kw)
+
+ from sqlalchemy.ext.automap import automap_base
+ from sqlalchemy import create_engine
+
+ # automap base
+ Base = automap_base()
+
+ engine = create_engine("sqlite:///mydatabase.db")
+ Base.prepare(autoload_with=engine,
+ generate_relationship=_gen_relationship)
+
+Many-to-Many relationships
+--------------------------
+
+:mod:`.sqlalchemy.ext.automap` will generate many-to-many relationships, e.g.
+those which contain a ``secondary`` argument. The process for producing these
+is as follows:
+
+1. A given :class:`_schema.Table` is examined for
+ :class:`_schema.ForeignKeyConstraint`
+ objects, before any mapped class has been assigned to it.
+
+2. If the table contains two and exactly two
+ :class:`_schema.ForeignKeyConstraint`
+ objects, and all columns within this table are members of these two
+ :class:`_schema.ForeignKeyConstraint` objects, the table is assumed to be a
+ "secondary" table, and will **not be mapped directly**.
+
+3. The two (or one, for self-referential) external tables to which the
+ :class:`_schema.Table`
+   refers are matched to the classes to which they will be
+ mapped, if any.
+
+4. If mapped classes for both sides are located, a many-to-many bi-directional
+ :func:`_orm.relationship` / :func:`.backref`
+ pair is created between the two
+ classes.
+
+5. The override logic for many-to-many works the same as that of one-to-many/
+ many-to-one; the :func:`.generate_relationship` function is called upon
+ to generate the structures and existing attributes will be maintained.
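+
+For example, given the criteria above, a sketch of a table that automap
+would treat as "secondary", linking assumed ``user`` and ``keyword``
+tables::
+
+    Table(
+        "user_keyword",
+        metadata,
+        Column("user_id", ForeignKey("user.id"), primary_key=True),
+        Column("keyword_id", ForeignKey("keyword.id"), primary_key=True),
+    )
+
+As every column is a member of one of the two
+:class:`_schema.ForeignKeyConstraint` objects, ``user_keyword`` itself is not
+mapped; instead, default-named collections such as ``keyword_collection``
+and ``user_collection`` are generated on the two referring classes.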
+
+Relationships with Inheritance
+------------------------------
+
+:mod:`.sqlalchemy.ext.automap` will not generate any relationships between
+two classes that are in an inheritance relationship. That is, with two
+classes given as follows::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+ id = Column(Integer, primary_key=True)
+ type = Column(String(50))
+ __mapper_args__ = {
+ 'polymorphic_identity':'employee', 'polymorphic_on': type
+ }
+
+ class Engineer(Employee):
+ __tablename__ = 'engineer'
+ id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
+ __mapper_args__ = {
+ 'polymorphic_identity':'engineer',
+ }
+
+The foreign key from ``Engineer`` to ``Employee`` is used not for a
+relationship, but to establish joined inheritance between the two classes.
+
+Note that this means automap will not generate *any* relationships
+for foreign keys that link from a subclass to a superclass. If a mapping
+has actual relationships from subclass to superclass as well, those
+need to be explicit. Below, as we have two separate foreign keys
+from ``Engineer`` to ``Employee``, we need to set up both the relationship
+we want as well as the ``inherit_condition``, as these are not things
+SQLAlchemy can guess::
+
+ class Employee(Base):
+ __tablename__ = 'employee'
+ id = Column(Integer, primary_key=True)
+ type = Column(String(50))
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'employee', 'polymorphic_on':type
+ }
+
+ class Engineer(Employee):
+ __tablename__ = 'engineer'
+ id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
+ favorite_employee_id = Column(Integer, ForeignKey('employee.id'))
+
+ favorite_employee = relationship(Employee,
+ foreign_keys=favorite_employee_id)
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'engineer',
+ 'inherit_condition': id == Employee.id
+ }
+
+Handling Simple Naming Conflicts
+--------------------------------
+
+In the case of naming conflicts during mapping, override any of
+:func:`.classname_for_table`, :func:`.name_for_scalar_relationship`,
+and :func:`.name_for_collection_relationship` as needed. For example, if
+automap is attempting to name a many-to-one relationship the same as an
+existing column, an alternate convention can be conditionally selected. Given
+a schema:
+
+.. sourcecode:: sql
+
+ CREATE TABLE table_a (
+ id INTEGER PRIMARY KEY
+ );
+
+ CREATE TABLE table_b (
+ id INTEGER PRIMARY KEY,
+ table_a INTEGER,
+ FOREIGN KEY(table_a) REFERENCES table_a(id)
+ );
+
+The above schema will first automap the ``table_a`` table as a class named
+``table_a``; it will then automap a relationship onto the class for ``table_b``
+with the same name as this related class, e.g. ``table_a``. This
+relationship name conflicts with the mapped column ``table_b.table_a``,
+and will emit an error on mapping.
+
+We can resolve this conflict by using an underscore as follows::
+
+    import warnings
+
+    def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
+ name = referred_cls.__name__.lower()
+ local_table = local_cls.__table__
+ if name in local_table.columns:
+ newname = name + "_"
+ warnings.warn(
+                "Already detected name %s present. Using %s" %
+ (name, newname))
+ return newname
+ return name
+
+
+ Base.prepare(autoload_with=engine,
+ name_for_scalar_relationship=name_for_scalar_relationship)
+
+Alternatively, we can change the name on the column side. The columns
+that are mapped can be modified using the technique described at
+:ref:`mapper_column_distinct_names`, by assigning the column explicitly
+to a new name::
+
+    from sqlalchemy import Column, ForeignKey
+
+    Base = automap_base()
+
+ class TableB(Base):
+ __tablename__ = 'table_b'
+ _table_a = Column('table_a', ForeignKey('table_a.id'))
+
+ Base.prepare(autoload_with=engine)
+
+
+Using Automap with Explicit Declarations
+========================================
+
+As noted previously, automap has no dependency on reflection, and can make
+use of any collection of :class:`_schema.Table` objects within a
+:class:`_schema.MetaData`
+collection. From this, it follows that automap can also be used
+to generate missing relationships given an otherwise complete model that fully
+defines table metadata::
+
+ from sqlalchemy.ext.automap import automap_base
+ from sqlalchemy import Column, Integer, String, ForeignKey
+
+ Base = automap_base()
+
+ class User(Base):
+ __tablename__ = 'user'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ class Address(Base):
+ __tablename__ = 'address'
+
+ id = Column(Integer, primary_key=True)
+ email = Column(String)
+ user_id = Column(ForeignKey('user.id'))
+
+ # produce relationships
+ Base.prepare()
+
+ # mapping is complete, with "address_collection" and
+ # "user" relationships
+ a1 = Address(email='u1')
+ a2 = Address(email='u2')
+ u1 = User(address_collection=[a1, a2])
+ assert a1.user is u1
+
+Above, given mostly complete ``User`` and ``Address`` mappings, the
+:class:`_schema.ForeignKey` which we defined on ``Address.user_id`` allowed a
+bidirectional relationship pair ``Address.user`` and
+``User.address_collection`` to be generated on the mapped classes.
+
+Note that when subclassing :class:`.AutomapBase`,
+the :meth:`.AutomapBase.prepare` method is required; if not called, the classes
+we've declared are in an un-mapped state.
+
+
+.. _automap_intercepting_columns:
+
+Intercepting Column Definitions
+===============================
+
+The :class:`_schema.MetaData` and :class:`_schema.Table` objects support an
+event hook :meth:`_events.DDLEvents.column_reflect` that may be used to intercept
+the information reflected about a database column before the :class:`_schema.Column`
+object is constructed. For example, if we wanted to map columns using a
+naming convention such as ``"attr_<columnname>"``, the event could
+be applied as::
+
+    from sqlalchemy import event
+
+    @event.listens_for(Base.metadata, "column_reflect")
+ def column_reflect(inspector, table, column_info):
+ # set column.key = "attr_<lower_case_name>"
+ column_info['key'] = "attr_%s" % column_info['name'].lower()
+
+ # run reflection
+ Base.prepare(autoload_with=engine)
+
+.. versionadded:: 1.4.0b2 the :meth:`_events.DDLEvents.column_reflect` event
+ may be applied to a :class:`_schema.MetaData` object.
+
+.. seealso::
+
+ :meth:`_events.DDLEvents.column_reflect`
+
+ :ref:`mapper_automated_reflection_schemes` - in the ORM mapping documentation
+
+
+""" # noqa
+from .. import util
+from ..orm import backref
+from ..orm import declarative_base as _declarative_base
+from ..orm import exc as orm_exc
+from ..orm import interfaces
+from ..orm import relationship
+from ..orm.decl_base import _DeferredMapperConfig
+from ..orm.mapper import _CONFIGURE_MUTEX
+from ..schema import ForeignKeyConstraint
+from ..sql import and_
+
+
+def classname_for_table(base, tablename, table):
+ """Return the class name that should be used, given the name
+ of a table.
+
+ The default implementation is::
+
+ return str(tablename)
+
+ Alternate implementations can be specified using the
+ :paramref:`.AutomapBase.prepare.classname_for_table`
+ parameter.
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+ :param tablename: string name of the :class:`_schema.Table`.
+
+ :param table: the :class:`_schema.Table` object itself.
+
+ :return: a string class name.
+
+ .. note::
+
+ In Python 2, the string used for the class name **must** be a
+ non-Unicode object, e.g. a ``str()`` object. The ``.name`` attribute
+ of :class:`_schema.Table` is typically a Python unicode subclass,
+ so the
+ ``str()`` function should be applied to this name, after accounting for
+ any non-ASCII characters.
+
+ """
+ return str(tablename)
+
+
+def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
+ """Return the attribute name that should be used to refer from one
+ class to another, for a scalar object reference.
+
+ The default implementation is::
+
+ return referred_cls.__name__.lower()
+
+ Alternate implementations can be specified using the
+ :paramref:`.AutomapBase.prepare.name_for_scalar_relationship`
+ parameter.
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+ :param local_cls: the class to be mapped on the local side.
+
+ :param referred_cls: the class to be mapped on the referring side.
+
+ :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being
+ inspected to produce this relationship.
+
+ """
+ return referred_cls.__name__.lower()
+
+
+def name_for_collection_relationship(
+ base, local_cls, referred_cls, constraint
+):
+ """Return the attribute name that should be used to refer from one
+ class to another, for a collection reference.
+
+ The default implementation is::
+
+ return referred_cls.__name__.lower() + "_collection"
+
+ Alternate implementations
+ can be specified using the
+ :paramref:`.AutomapBase.prepare.name_for_collection_relationship`
+ parameter.
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+ :param local_cls: the class to be mapped on the local side.
+
+ :param referred_cls: the class to be mapped on the referring side.
+
+ :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being
+ inspected to produce this relationship.
+
+ """
+ return referred_cls.__name__.lower() + "_collection"
+
+
+def generate_relationship(
+ base, direction, return_fn, attrname, local_cls, referred_cls, **kw
+):
+ r"""Generate a :func:`_orm.relationship` or :func:`.backref`
+ on behalf of two
+ mapped classes.
+
+ An alternate implementation of this function can be specified using the
+ :paramref:`.AutomapBase.prepare.generate_relationship` parameter.
+
+ The default implementation of this function is as follows::
+
+ if return_fn is backref:
+ return return_fn(attrname, **kw)
+ elif return_fn is relationship:
+ return return_fn(referred_cls, **kw)
+ else:
+ raise TypeError("Unknown relationship function: %s" % return_fn)
+
+ :param base: the :class:`.AutomapBase` class doing the prepare.
+
+ :param direction: indicate the "direction" of the relationship; this will
+ be one of :data:`.ONETOMANY`, :data:`.MANYTOONE`, :data:`.MANYTOMANY`.
+
+ :param return_fn: the function that is used by default to create the
+ relationship. This will be either :func:`_orm.relationship` or
+ :func:`.backref`. The :func:`.backref` function's result will be used to
+ produce a new :func:`_orm.relationship` in a second step,
+ so it is critical
+ that user-defined implementations correctly differentiate between the two
+ functions, if a custom relationship function is being used.
+
+ :param attrname: the attribute name to which this relationship is being
+ assigned. If the value of :paramref:`.generate_relationship.return_fn` is
+ the :func:`.backref` function, then this name is the name that is being
+ assigned to the backref.
+
+ :param local_cls: the "local" class to which this relationship or backref
+ will be locally present.
+
+ :param referred_cls: the "referred" class to which the relationship or
+      backref refers.
+
+ :param \**kw: all additional keyword arguments are passed along to the
+ function.
+
+ :return: a :func:`_orm.relationship` or :func:`.backref` construct,
+ as dictated
+ by the :paramref:`.generate_relationship.return_fn` parameter.
+
+ """
+ if return_fn is backref:
+ return return_fn(attrname, **kw)
+ elif return_fn is relationship:
+ return return_fn(referred_cls, **kw)
+ else:
+ raise TypeError("Unknown relationship function: %s" % return_fn)
+
+
+class AutomapBase(object):
+ """Base class for an "automap" schema.
+
+ The :class:`.AutomapBase` class can be compared to the "declarative base"
+ class that is produced by the :func:`.declarative.declarative_base`
+ function. In practice, the :class:`.AutomapBase` class is always used
+ as a mixin along with an actual declarative base.
+
+ A new subclassable :class:`.AutomapBase` is typically instantiated
+ using the :func:`.automap_base` function.
+
+ .. seealso::
+
+ :ref:`automap_toplevel`
+
+ """
+
+ __abstract__ = True
+
+ classes = None
+ """An instance of :class:`.util.Properties` containing classes.
+
+ This object behaves much like the ``.c`` collection on a table. Classes
+ are present under the name they were given, e.g.::
+
+ Base = automap_base()
+ Base.prepare(autoload_with=some_engine)
+
+ User, Address = Base.classes.User, Base.classes.Address
+
+ """
+
+ @classmethod
+ @util.deprecated_params(
+ engine=(
+ "2.0",
+ "The :paramref:`_automap.AutomapBase.prepare.engine` parameter "
+ "is deprecated and will be removed in a future release. "
+ "Please use the "
+ ":paramref:`_automap.AutomapBase.prepare.autoload_with` "
+ "parameter.",
+ ),
+ reflect=(
+ "2.0",
+ "The :paramref:`_automap.AutomapBase.prepare.reflect` "
+ "parameter is deprecated and will be removed in a future "
+ "release. Reflection is enabled when "
+ ":paramref:`_automap.AutomapBase.prepare.autoload_with` "
+ "is passed.",
+ ),
+ )
+ def prepare(
+ cls,
+ autoload_with=None,
+ engine=None,
+ reflect=False,
+ schema=None,
+ classname_for_table=None,
+ collection_class=None,
+ name_for_scalar_relationship=None,
+ name_for_collection_relationship=None,
+ generate_relationship=None,
+ reflection_options=util.EMPTY_DICT,
+ ):
+ """Extract mapped classes and relationships from the
+ :class:`_schema.MetaData` and
+ perform mappings.
+
+ :param engine: an :class:`_engine.Engine` or
+ :class:`_engine.Connection` with which
+ to perform schema reflection, if specified.
+ If the :paramref:`.AutomapBase.prepare.reflect` argument is False,
+ this object is not used.
+
+ :param reflect: if True, the :meth:`_schema.MetaData.reflect`
+ method is called
+ on the :class:`_schema.MetaData` associated with this
+ :class:`.AutomapBase`.
+ The :class:`_engine.Engine` passed via
+ :paramref:`.AutomapBase.prepare.engine` will be used to perform the
+ reflection if present; else, the :class:`_schema.MetaData`
+ should already be
+          bound to some engine, or else the operation will fail.
+
+ :param classname_for_table: callable function which will be used to
+ produce new class names, given a table name. Defaults to
+ :func:`.classname_for_table`.
+
+ :param name_for_scalar_relationship: callable function which will be
+ used to produce relationship names for scalar relationships. Defaults
+ to :func:`.name_for_scalar_relationship`.
+
+ :param name_for_collection_relationship: callable function which will
+ be used to produce relationship names for collection-oriented
+ relationships. Defaults to :func:`.name_for_collection_relationship`.
+
+ :param generate_relationship: callable function which will be used to
+ actually generate :func:`_orm.relationship` and :func:`.backref`
+ constructs. Defaults to :func:`.generate_relationship`.
+
+ :param collection_class: the Python collection class that will be used
+ when a new :func:`_orm.relationship`
+ object is created that represents a
+ collection. Defaults to ``list``.
+
+ :param schema: When present in conjunction with the
+ :paramref:`.AutomapBase.prepare.reflect` flag, is passed to
+ :meth:`_schema.MetaData.reflect`
+ to indicate the primary schema where tables
+ should be reflected from. When omitted, the default schema in use
+ by the database connection is used.
+
+ .. versionadded:: 1.1
+
+ :param reflection_options: When present, this dictionary of options
+ will be passed to :meth:`_schema.MetaData.reflect`
+ to supply general reflection-specific options like ``only`` and/or
+ dialect-specific options like ``oracle_resolve_synonyms``.
+
+ .. versionadded:: 1.4
+
+ """
+ glbls = globals()
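+        # hooks not passed by the caller default to the module-level
+        # functions of the same name, e.g. classname_for_table()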
+ if classname_for_table is None:
+ classname_for_table = glbls["classname_for_table"]
+ if name_for_scalar_relationship is None:
+ name_for_scalar_relationship = glbls[
+ "name_for_scalar_relationship"
+ ]
+ if name_for_collection_relationship is None:
+ name_for_collection_relationship = glbls[
+ "name_for_collection_relationship"
+ ]
+ if generate_relationship is None:
+ generate_relationship = glbls["generate_relationship"]
+ if collection_class is None:
+ collection_class = list
+
+ if autoload_with:
+ reflect = True
+
+ if engine:
+ autoload_with = engine
+
+ if reflect:
+ opts = dict(
+ schema=schema,
+ extend_existing=True,
+ autoload_replace=False,
+ )
+ if reflection_options:
+ opts.update(reflection_options)
+ cls.metadata.reflect(autoload_with, **opts)
+
+ with _CONFIGURE_MUTEX:
+ table_to_map_config = dict(
+ (m.local_table, m)
+ for m in _DeferredMapperConfig.classes_for_base(
+ cls, sort=False
+ )
+ )
+
+ many_to_many = []
+
+ for table in cls.metadata.tables.values():
+ lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table)
+ if lcl_m2m is not None:
+ many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table))
+ elif not table.primary_key:
+ continue
+ elif table not in table_to_map_config:
+ mapped_cls = type(
+ classname_for_table(cls, table.name, table),
+ (cls,),
+ {"__table__": table},
+ )
+ map_config = _DeferredMapperConfig.config_for_cls(
+ mapped_cls
+ )
+ cls.classes[map_config.cls.__name__] = mapped_cls
+ table_to_map_config[table] = map_config
+
+ for map_config in table_to_map_config.values():
+ _relationships_for_fks(
+ cls,
+ map_config,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+ )
+
+ for lcl_m2m, rem_m2m, m2m_const, table in many_to_many:
+ _m2m_relationship(
+ cls,
+ lcl_m2m,
+ rem_m2m,
+ m2m_const,
+ table,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+ )
+
+ for map_config in _DeferredMapperConfig.classes_for_base(cls):
+ map_config.map()
+
+ _sa_decl_prepare = True
+ """Indicate that the mapping of classes should be deferred.
+
+ The presence of this attribute name indicates to declarative
+ that the call to mapper() should not occur immediately; instead,
+ information about the table and attributes to be mapped are gathered
+ into an internal structure called _DeferredMapperConfig. These
+ objects can be collected later using classes_for_base(), additional
+ mapping decisions can be made, and then the map() method will actually
+ apply the mapping.
+
+ The only real reason this deferral of the whole
+ thing is needed is to support primary key columns that aren't reflected
+ yet when the class is declared; everything else can theoretically be
+    added to the mapper later.  However, the _DeferredMapperConfig is in any
+    case a convenient interface, existing at the otherwise unexposed point at
+    which declarative has the class and the Table but hasn't yet called
+    mapper().
+
+ """
+
+ @classmethod
+ def _sa_raise_deferred_config(cls):
+ raise orm_exc.UnmappedClassError(
+ cls,
+ msg="Class %s is a subclass of AutomapBase. "
+ "Mappings are not produced until the .prepare() "
+ "method is called on the class hierarchy."
+ % orm_exc._safe_cls_name(cls),
+ )
+
+
+def automap_base(declarative_base=None, **kw):
+ r"""Produce a declarative automap base.
+
+ This function produces a new base class that is a product of the
+ :class:`.AutomapBase` class as well a declarative base produced by
+ :func:`.declarative.declarative_base`.
+
+ All parameters other than ``declarative_base`` are keyword arguments
+ that are passed directly to the :func:`.declarative.declarative_base`
+ function.
+
+ :param declarative_base: an existing class produced by
+ :func:`.declarative.declarative_base`. When this is passed, the function
+ no longer invokes :func:`.declarative.declarative_base` itself, and all
+ other keyword arguments are ignored.
+
+ :param \**kw: keyword arguments are passed along to
+ :func:`.declarative.declarative_base`.
+
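+    E.g., a sketch passing a pre-existing declarative base::
+
+        from sqlalchemy.orm import declarative_base
+        from sqlalchemy.ext.automap import automap_base
+
+        DeclBase = declarative_base()
+        Base = automap_base(declarative_base=DeclBase)
+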
+ """
+ if declarative_base is None:
+ Base = _declarative_base(**kw)
+ else:
+ Base = declarative_base
+
+ return type(
+ Base.__name__,
+ (AutomapBase, Base),
+ {"__abstract__": True, "classes": util.Properties({})},
+ )
+
+
+def _is_many_to_many(automap_base, table):
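+    # per the "Many-to-Many relationships" section of the module docstring:
+    # a table with exactly two ForeignKeyConstraints, where every column is
+    # a member of those constraints, is treated as an m2m "secondary" table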
+ fk_constraints = [
+ const
+ for const in table.constraints
+ if isinstance(const, ForeignKeyConstraint)
+ ]
+ if len(fk_constraints) != 2:
+ return None, None, None
+
+ cols = sum(
+ [
+ [fk.parent for fk in fk_constraint.elements]
+ for fk_constraint in fk_constraints
+ ],
+ [],
+ )
+
+ if set(cols) != set(table.c):
+ return None, None, None
+
+ return (
+ fk_constraints[0].elements[0].column.table,
+ fk_constraints[1].elements[0].column.table,
+ fk_constraints,
+ )
+
+
+def _relationships_for_fks(
+ automap_base,
+ map_config,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+):
+ local_table = map_config.local_table
+ local_cls = map_config.cls # derived from a weakref, may be None
+
+ if local_table is None or local_cls is None:
+ return
+ for constraint in local_table.constraints:
+ if isinstance(constraint, ForeignKeyConstraint):
+ fks = constraint.elements
+ referred_table = fks[0].column.table
+ referred_cfg = table_to_map_config.get(referred_table, None)
+ if referred_cfg is None:
+ continue
+ referred_cls = referred_cfg.cls
+
+ if local_cls is not referred_cls and issubclass(
+ local_cls, referred_cls
+ ):
+ continue
+
+ relationship_name = name_for_scalar_relationship(
+ automap_base, local_cls, referred_cls, constraint
+ )
+ backref_name = name_for_collection_relationship(
+ automap_base, referred_cls, local_cls, constraint
+ )
+
+ o2m_kws = {}
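+            # per the module docstring: any NOT NULL column in the
+            # constraint implies "all, delete-orphan" cascade on the
+            # one-to-many side; ON DELETE CASCADE / SET NULL additionally
+            # enables passive_deletes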
+ nullable = False not in {fk.parent.nullable for fk in fks}
+ if not nullable:
+ o2m_kws["cascade"] = "all, delete-orphan"
+
+ if (
+ constraint.ondelete
+ and constraint.ondelete.lower() == "cascade"
+ ):
+ o2m_kws["passive_deletes"] = True
+            elif (
+                constraint.ondelete
+                and constraint.ondelete.lower() == "set null"
+            ):
+                o2m_kws["passive_deletes"] = True
+
+ create_backref = backref_name not in referred_cfg.properties
+
+ if relationship_name not in map_config.properties:
+ if create_backref:
+ backref_obj = generate_relationship(
+ automap_base,
+ interfaces.ONETOMANY,
+ backref,
+ backref_name,
+ referred_cls,
+ local_cls,
+ collection_class=collection_class,
+ **o2m_kws
+ )
+ else:
+ backref_obj = None
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOONE,
+ relationship,
+ relationship_name,
+ local_cls,
+ referred_cls,
+ foreign_keys=[fk.parent for fk in constraint.elements],
+ backref=backref_obj,
+ remote_side=[fk.column for fk in constraint.elements],
+ )
+ if rel is not None:
+ map_config.properties[relationship_name] = rel
+ if not create_backref:
+ referred_cfg.properties[
+ backref_name
+ ].back_populates = relationship_name
+ elif create_backref:
+ rel = generate_relationship(
+ automap_base,
+ interfaces.ONETOMANY,
+ relationship,
+ backref_name,
+ referred_cls,
+ local_cls,
+ foreign_keys=[fk.parent for fk in constraint.elements],
+ back_populates=relationship_name,
+ collection_class=collection_class,
+ **o2m_kws
+ )
+ if rel is not None:
+ referred_cfg.properties[backref_name] = rel
+ map_config.properties[
+ relationship_name
+ ].back_populates = backref_name
+
+
+def _m2m_relationship(
+ automap_base,
+ lcl_m2m,
+ rem_m2m,
+ m2m_const,
+ table,
+ table_to_map_config,
+ collection_class,
+ name_for_scalar_relationship,
+ name_for_collection_relationship,
+ generate_relationship,
+):
+
+ map_config = table_to_map_config.get(lcl_m2m, None)
+ referred_cfg = table_to_map_config.get(rem_m2m, None)
+ if map_config is None or referred_cfg is None:
+ return
+
+ local_cls = map_config.cls
+ referred_cls = referred_cfg.cls
+
+ relationship_name = name_for_collection_relationship(
+ automap_base, local_cls, referred_cls, m2m_const[0]
+ )
+ backref_name = name_for_collection_relationship(
+ automap_base, referred_cls, local_cls, m2m_const[1]
+ )
+
+ create_backref = backref_name not in referred_cfg.properties
+
+ if table in table_to_map_config:
+ overlaps = "__*"
+ else:
+ overlaps = None
+
+ if relationship_name not in map_config.properties:
+ if create_backref:
+ backref_obj = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ backref,
+ backref_name,
+ referred_cls,
+ local_cls,
+ collection_class=collection_class,
+ overlaps=overlaps,
+ )
+ else:
+ backref_obj = None
+
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ relationship,
+ relationship_name,
+ local_cls,
+ referred_cls,
+ overlaps=overlaps,
+ secondary=table,
+ primaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[0].elements
+ ),
+ secondaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[1].elements
+ ),
+ backref=backref_obj,
+ collection_class=collection_class,
+ )
+ if rel is not None:
+ map_config.properties[relationship_name] = rel
+
+ if not create_backref:
+ referred_cfg.properties[
+ backref_name
+ ].back_populates = relationship_name
+ elif create_backref:
+ rel = generate_relationship(
+ automap_base,
+ interfaces.MANYTOMANY,
+ relationship,
+ backref_name,
+ referred_cls,
+ local_cls,
+ overlaps=overlaps,
+ secondary=table,
+ primaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[1].elements
+ ),
+ secondaryjoin=and_(
+ fk.column == fk.parent for fk in m2m_const[0].elements
+ ),
+ back_populates=relationship_name,
+ collection_class=collection_class,
+ )
+ if rel is not None:
+ referred_cfg.properties[backref_name] = rel
+ map_config.properties[
+ relationship_name
+ ].back_populates = backref_name
diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py
new file mode 100644
index 0000000..109e0c0
--- /dev/null
+++ b/lib/sqlalchemy/ext/baked.py
@@ -0,0 +1,648 @@
+# sqlalchemy/ext/baked.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Baked query extension.
+
+Provides a creational pattern for the :class:`.query.Query` object which
+allows the fully constructed object, Core select statement, and string
+compiled result to be fully cached.
+
+
+"""
+
+import logging
+
+from .. import exc as sa_exc
+from .. import util
+from ..orm import exc as orm_exc
+from ..orm import strategy_options
+from ..orm.query import Query
+from ..orm.session import Session
+from ..sql import func
+from ..sql import literal_column
+from ..sql import util as sql_util
+from ..util import collections_abc
+
+
+log = logging.getLogger(__name__)
+
+
+class Bakery(object):
+ """Callable which returns a :class:`.BakedQuery`.
+
+ This object is returned by the class method
+ :meth:`.BakedQuery.bakery`. It exists as an object
+ so that the "cache" can be easily inspected.
+
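+    E.g., a sketch, where ``User`` is an assumed mapped class::
+
+        bakery = BakedQuery.bakery()
+        baked_query = bakery(lambda session: session.query(User))
+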
+ .. versionadded:: 1.2
+
+
+ """
+
+ __slots__ = "cls", "cache"
+
+ def __init__(self, cls_, cache):
+ self.cls = cls_
+ self.cache = cache
+
+ def __call__(self, initial_fn, *args):
+ return self.cls(self.cache, initial_fn, args)
+
+
+class BakedQuery(object):
+ """A builder object for :class:`.query.Query` objects."""
+
+ __slots__ = "steps", "_bakery", "_cache_key", "_spoiled"
+
+ def __init__(self, bakery, initial_fn, args=()):
+ self._cache_key = ()
+ self._update_cache_key(initial_fn, args)
+ self.steps = [initial_fn]
+ self._spoiled = False
+ self._bakery = bakery
+
+ @classmethod
+ def bakery(cls, size=200, _size_alert=None):
+ """Construct a new bakery.
+
+ :return: an instance of :class:`.Bakery`
+
+ """
+
+ return Bakery(cls, util.LRUCache(size, size_alert=_size_alert))
+
+ def _clone(self):
+ b1 = BakedQuery.__new__(BakedQuery)
+ b1._cache_key = self._cache_key
+ b1.steps = list(self.steps)
+ b1._bakery = self._bakery
+ b1._spoiled = self._spoiled
+ return b1
+
+ def _update_cache_key(self, fn, args=()):
+ self._cache_key += (fn.__code__,) + args
+
+ def __iadd__(self, other):
+ if isinstance(other, tuple):
+ self.add_criteria(*other)
+ else:
+ self.add_criteria(other)
+ return self
+
+ def __add__(self, other):
+ if isinstance(other, tuple):
+ return self.with_criteria(*other)
+ else:
+ return self.with_criteria(other)
+
+ def add_criteria(self, fn, *args):
+ """Add a criteria function to this :class:`.BakedQuery`.
+
+ This is equivalent to using the ``+=`` operator to
+ modify a :class:`.BakedQuery` in-place.
+
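+        E.g., a sketch, assuming ``User`` is a mapped class::
+
+            from sqlalchemy import bindparam
+
+            baked_query.add_criteria(
+                lambda q: q.filter(User.name == bindparam("username"))
+            )
+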
+ """
+ self._update_cache_key(fn, args)
+ self.steps.append(fn)
+ return self
+
+ def with_criteria(self, fn, *args):
+ """Add a criteria function to a :class:`.BakedQuery` cloned from this
+ one.
+
+ This is equivalent to using the ``+`` operator to
+ produce a new :class:`.BakedQuery` with modifications.
+
+ """
+ return self._clone().add_criteria(fn, *args)
+
+ def for_session(self, session):
+ """Return a :class:`_baked.Result` object for this
+ :class:`.BakedQuery`.
+
+ This is equivalent to calling the :class:`.BakedQuery` as a
+ Python callable, e.g. ``result = my_baked_query(session)``.
+
+ """
+ return Result(self, session)
+
+ def __call__(self, session):
+ return self.for_session(session)
+
+ def spoil(self, full=False):
+ """Cancel any query caching that will occur on this BakedQuery object.
+
+        The BakedQuery can continue to be used normally; however, additional
+ creational functions will not be cached; they will be called
+ on every invocation.
+
+ This is to support the case where a particular step in constructing
+ a baked query disqualifies the query from being cacheable, such
+ as a variant that relies upon some uncacheable value.
+
+ :param full: if False, only functions added to this
+ :class:`.BakedQuery` object subsequent to the spoil step will be
+ non-cached; the state of the :class:`.BakedQuery` up until
+ this point will be pulled from the cache. If True, then the
+ entire :class:`_query.Query` object is built from scratch each
+ time, with all creational functions being called on each
+ invocation.
+
+ """
+ if not full and not self._spoiled:
+ _spoil_point = self._clone()
+ _spoil_point._cache_key += ("_query_only",)
+ self.steps = [_spoil_point._retrieve_baked_query]
+ self._spoiled = True
+ return self
+
+ def _effective_key(self, session):
+ """Return the key that actually goes into the cache dictionary for
+ this :class:`.BakedQuery`, taking into account the given
+ :class:`.Session`.
+
+ This basically means we also will include the session's query_class,
+ as the actual :class:`_query.Query` object is part of what's cached
+ and needs to match the type of :class:`_query.Query` that a later
+ session will want to use.
+
+ """
+ return self._cache_key + (session._query_cls,)
+
+ def _with_lazyload_options(self, options, effective_path, cache_path=None):
+ """Cloning version of _add_lazyload_options."""
+ q = self._clone()
+ q._add_lazyload_options(options, effective_path, cache_path=cache_path)
+ return q
+
+ def _add_lazyload_options(self, options, effective_path, cache_path=None):
+ """Used by per-state lazy loaders to add options to the
+ "lazy load" query from a parent query.
+
+ Creates a cache key based on given load path and query options;
+ if a repeatable cache key cannot be generated, the query is
+ "spoiled" so that it won't use caching.
+
+ """
+
+ key = ()
+
+ if not cache_path:
+ cache_path = effective_path
+
+ for opt in options:
+ if opt._is_legacy_option or opt._is_compile_state:
+ ck = opt._generate_cache_key()
+ if ck is None:
+ self.spoil(full=True)
+ else:
+ assert not ck[1], (
+ "loader options with variable bound parameters "
+ "not supported with baked queries. Please "
+ "use new-style select() statements for cached "
+ "ORM queries."
+ )
+ key += ck[0]
+
+ self.add_criteria(
+ lambda q: q._with_current_path(effective_path).options(*options),
+ cache_path.path,
+ key,
+ )
+
+ def _retrieve_baked_query(self, session):
+ query = self._bakery.get(self._effective_key(session), None)
+ if query is None:
+ query = self._as_query(session)
+ self._bakery[self._effective_key(session)] = query.with_session(
+ None
+ )
+ return query.with_session(session)
+
+ def _bake(self, session):
+ query = self._as_query(session)
+ query.session = None
+
+ # in 1.4, this is where before_compile() event is
+ # invoked
+ statement = query._statement_20()
+
+ # if the query is not safe to cache, we still do everything as though
+ # we did cache it, since the receiver of _bake() assumes subqueryload
+ # context was set up, etc.
+ #
+ # note also we want to cache the statement itself because this
+ # allows the statement itself to hold onto its cache key that is
+ # used by the Connection, which in itself is more expensive to
+ # generate than what BakedQuery was able to provide in 1.3 and prior
+
+ if statement._compile_options._bake_ok:
+ self._bakery[self._effective_key(session)] = (
+ query,
+ statement,
+ )
+
+ return query, statement
+
+ def to_query(self, query_or_session):
+ """Return the :class:`_query.Query` object for use as a subquery.
+
+ This method should be used within the lambda callable being used
+ to generate a step of an enclosing :class:`.BakedQuery`. The
+ parameter should normally be the :class:`_query.Query` object that
+ is passed to the lambda::
+
+ sub_bq = self.bakery(lambda s: s.query(User.name))
+ sub_bq += lambda q: q.filter(
+ User.id == Address.user_id).correlate(Address)
+
+ main_bq = self.bakery(lambda s: s.query(Address))
+ main_bq += lambda q: q.filter(
+ sub_bq.to_query(q).exists())
+
+ In the case where the subquery is used in the first callable against
+ a :class:`.Session`, the :class:`.Session` is also accepted::
+
+ sub_bq = self.bakery(lambda s: s.query(User.name))
+ sub_bq += lambda q: q.filter(
+ User.id == Address.user_id).correlate(Address)
+
+ main_bq = self.bakery(
+ lambda s: s.query(
+                    Address.id, sub_bq.to_query(s).scalar_subquery())
+ )
+
+        :param query_or_session: a :class:`_query.Query` object or a
+         :class:`.Session` object, which is assumed to be within the context
+ of an enclosing :class:`.BakedQuery` callable.
+
+
+ .. versionadded:: 1.3
+
+
+ """
+
+ if isinstance(query_or_session, Session):
+ session = query_or_session
+ elif isinstance(query_or_session, Query):
+ session = query_or_session.session
+ if session is None:
+ raise sa_exc.ArgumentError(
+ "Given Query needs to be associated with a Session"
+ )
+ else:
+ raise TypeError(
+ "Query or Session object expected, got %r."
+ % type(query_or_session)
+ )
+ return self._as_query(session)
+
+ def _as_query(self, session):
+ query = self.steps[0](session)
+
+ for step in self.steps[1:]:
+ query = step(query)
+
+ return query
+
+
+class Result(object):
+ """Invokes a :class:`.BakedQuery` against a :class:`.Session`.
+
+ The :class:`_baked.Result` object is where the actual :class:`.query.Query`
+ object gets created, or retrieved from the cache,
+ against a target :class:`.Session`, and is then invoked for results.
+
+ """
+
+ __slots__ = "bq", "session", "_params", "_post_criteria"
+
+ def __init__(self, bq, session):
+ self.bq = bq
+ self.session = session
+ self._params = {}
+ self._post_criteria = []
+
+ def params(self, *args, **kw):
+ """Specify parameters to be replaced into the string SQL statement."""
+
+ if len(args) == 1:
+ kw.update(args[0])
+ elif len(args) > 0:
+ raise sa_exc.ArgumentError(
+ "params() takes zero or one positional argument, "
+ "which is a dictionary."
+ )
+ self._params.update(kw)
+ return self
+
+ def _using_post_criteria(self, fns):
+ if fns:
+ self._post_criteria.extend(fns)
+ return self
+
+ def with_post_criteria(self, fn):
+ """Add a criteria function that will be applied post-cache.
+
+ This adds a function that will be run against the
+ :class:`_query.Query` object after it is retrieved from the
+ cache. This currently includes **only** the
+ :meth:`_query.Query.params` and :meth:`_query.Query.execution_options`
+ methods.
+
+ .. warning:: :meth:`_baked.Result.with_post_criteria`
+ functions are applied
+ to the :class:`_query.Query`
+ object **after** the query's SQL statement
+ object has been retrieved from the cache. Only
+ :meth:`_query.Query.params` and
+ :meth:`_query.Query.execution_options`
+ methods should be used.
+
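+        E.g., a sketch supplying a bound parameter value after the cached
+        query is retrieved::
+
+            result = my_baked_query(session).with_post_criteria(
+                lambda q: q.params(name="some name")
+            )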
+
+ .. versionadded:: 1.2
+
+
+ """
+ return self._using_post_criteria([fn])
+
+ def _as_query(self):
+ q = self.bq._as_query(self.session).params(self._params)
+ for fn in self._post_criteria:
+ q = fn(q)
+ return q
+
+ def __str__(self):
+ return str(self._as_query())
+
+ def __iter__(self):
+ return self._iter().__iter__()
+
+ def _iter(self):
+ bq = self.bq
+
+ if not self.session.enable_baked_queries or bq._spoiled:
+ return self._as_query()._iter()
+
+ query, statement = bq._bakery.get(
+ bq._effective_key(self.session), (None, None)
+ )
+ if query is None:
+ query, statement = bq._bake(self.session)
+
+ if self._params:
+ q = query.params(self._params)
+ else:
+ q = query
+ for fn in self._post_criteria:
+ q = fn(q)
+
+ params = q._params
+ execution_options = dict(q._execution_options)
+ execution_options.update(
+ {
+ "_sa_orm_load_options": q.load_options,
+ "compiled_cache": bq._bakery,
+ }
+ )
+
+ result = self.session.execute(
+ statement, params, execution_options=execution_options
+ )
+ if result._attributes.get("is_single_entity", False):
+ result = result.scalars()
+
+ if result._attributes.get("filtered", False):
+ result = result.unique()
+
+ return result
+
+ def count(self):
+ """return the 'count'.
+
+ Equivalent to :meth:`_query.Query.count`.
+
+ Note this uses a subquery to ensure an accurate count regardless
+ of the structure of the original statement.
+
+ .. versionadded:: 1.1.6
+
+ """
+
+ col = func.count(literal_column("*"))
+ bq = self.bq.with_criteria(lambda q: q._from_self(col))
+ return bq.for_session(self.session).params(self._params).scalar()
+
+ def scalar(self):
+        """Return the first element of the first result or None
+        if no rows are present.  If multiple rows are returned,
+ raises MultipleResultsFound.
+
+ Equivalent to :meth:`_query.Query.scalar`.
+
+ .. versionadded:: 1.1.6
+
+ """
+ try:
+ ret = self.one()
+ if not isinstance(ret, collections_abc.Sequence):
+ return ret
+ return ret[0]
+ except orm_exc.NoResultFound:
+ return None
+
+ def first(self):
+ """Return the first row.
+
+ Equivalent to :meth:`_query.Query.first`.
+
+ """
+
+ bq = self.bq.with_criteria(lambda q: q.slice(0, 1))
+ return (
+ bq.for_session(self.session)
+ .params(self._params)
+ ._using_post_criteria(self._post_criteria)
+ ._iter()
+ .first()
+ )
+
+ def one(self):
+ """Return exactly one result or raise an exception.
+
+ Equivalent to :meth:`_query.Query.one`.
+
+ """
+ return self._iter().one()
+
+ def one_or_none(self):
+ """Return one or zero results, or raise an exception for multiple
+ rows.
+
+ Equivalent to :meth:`_query.Query.one_or_none`.
+
+ .. versionadded:: 1.0.9
+
+ """
+ return self._iter().one_or_none()
+
+ def all(self):
+ """Return all rows.
+
+ Equivalent to :meth:`_query.Query.all`.
+
+ """
+ return self._iter().all()
+
+ def get(self, ident):
+ """Retrieve an object based on identity.
+
+ Equivalent to :meth:`_query.Query.get`.
+
+ """
+
+ query = self.bq.steps[0](self.session)
+ return query._get_impl(ident, self._load_on_pk_identity)
+
+ def _load_on_pk_identity(self, session, query, primary_key_identity, **kw):
+ """Load the given primary key identity from the database."""
+
+ mapper = query._raw_columns[0]._annotations["parententity"]
+
+ _get_clause, _get_params = mapper._get_clause
+
+ def setup(query):
+ _lcl_get_clause = _get_clause
+ q = query._clone()
+ q._get_condition()
+ q._order_by = None
+
+ # None present in ident - turn those comparisons
+ # into "IS NULL"
+ if None in primary_key_identity:
+ nones = set(
+ [
+ _get_params[col].key
+ for col, value in zip(
+ mapper.primary_key, primary_key_identity
+ )
+ if value is None
+ ]
+ )
+ _lcl_get_clause = sql_util.adapt_criterion_to_null(
+ _lcl_get_clause, nones
+ )
+
+ # TODO: can mapper._get_clause be pre-adapted?
+ q._where_criteria = (
+ sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}),
+ )
+
+ for fn in self._post_criteria:
+ q = fn(q)
+ return q
+
+ # cache the query against a key that includes
+ # which positions in the primary key are NULL
+ # (remember, we can map to an OUTER JOIN)
+ bq = self.bq
+
+ # add the clause we got from mapper._get_clause to the cache
+ # key so that if a race causes multiple calls to _get_clause,
+ # we've cached on ours
+ bq = bq._clone()
+ bq._cache_key += (_get_clause,)
+
+ bq = bq.with_criteria(
+ setup, tuple(elem is None for elem in primary_key_identity)
+ )
+
+ params = dict(
+ [
+ (_get_params[primary_key].key, id_val)
+ for id_val, primary_key in zip(
+ primary_key_identity, mapper.primary_key
+ )
+ ]
+ )
+
+ result = list(bq.for_session(self.session).params(**params))
+ result_count = len(result)
+ if result_count > 1:
+ raise orm_exc.MultipleResultsFound()
+ elif result_count:
+ return result[0]
+ else:
+ return None
+
+
+@util.deprecated(
+ "1.2", "Baked lazy loading is now the default implementation."
+)
+def bake_lazy_loaders():
+ """Enable the use of baked queries for all lazyloaders systemwide.
+
+ The "baked" implementation of lazy loading is now the sole implementation
+ for the base lazy loader; this method has no effect except for a warning.
+
+ """
+ pass
+
+
+@util.deprecated(
+ "1.2", "Baked lazy loading is now the default implementation."
+)
+def unbake_lazy_loaders():
+ """Disable the use of baked queries for all lazyloaders systemwide.
+
+ This method now raises NotImplementedError() as the "baked" implementation
+ is the only lazy load implementation. The
+ :paramref:`_orm.relationship.bake_queries` flag may be used to disable
+ the caching of queries on a per-relationship basis.
+
+ """
+ raise NotImplementedError(
+ "Baked lazy loading is now the default implementation"
+ )
+
+
+@strategy_options.loader_option()
+def baked_lazyload(loadopt, attr):
+ """Indicate that the given attribute should be loaded using "lazy"
+ loading with a "baked" query used in the load.
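+
+ E.g., a minimal sketch, assuming a mapped ``User`` class with an
+ ``addresses`` relationship::
+
+     session.query(User).options(baked_lazyload(User.addresses))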
+
+ """
+ return loadopt.set_relationship_strategy(attr, {"lazy": "baked_select"})
+
+
+@baked_lazyload._add_unbound_fn
+@util.deprecated(
+ "1.2",
+ "Baked lazy loading is now the default "
+ "implementation for lazy loading.",
+)
+def baked_lazyload(*keys):
+ return strategy_options._UnboundLoad._from_keys(
+ strategy_options._UnboundLoad.baked_lazyload, keys, False, {}
+ )
+
+
+@baked_lazyload._add_unbound_all_fn
+@util.deprecated(
+ "1.2",
+ "Baked lazy loading is now the default "
+ "implementation for lazy loading.",
+)
+def baked_lazyload_all(*keys):
+ return strategy_options._UnboundLoad._from_keys(
+ strategy_options._UnboundLoad.baked_lazyload, keys, True, {}
+ )
+
+
+baked_lazyload = baked_lazyload._unbound_fn
+baked_lazyload_all = baked_lazyload_all._unbound_all_fn
+
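+# A minimal usage sketch for the module-level ``bakery`` callable assigned
+# below (assumes a mapped ``User`` class):
+#
+#     from sqlalchemy.ext import baked
+#     bakery = baked.bakery()
+#     baked_query = bakery(lambda session: session.query(User))
+#     results = baked_query.for_session(session).all()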
+bakery = BakedQuery.bakery
diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py
new file mode 100644
index 0000000..76b59ea
--- /dev/null
+++ b/lib/sqlalchemy/ext/compiler.py
@@ -0,0 +1,613 @@
+# ext/compiler.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Provides an API for creation of custom ClauseElements and compilers.
+
+Synopsis
+========
+
+Usage involves the creation of one or more
+:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or
+more callables defining its compilation::
+
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.sql.expression import ColumnClause
+
+ class MyColumn(ColumnClause):
+ inherit_cache = True
+
+ @compiles(MyColumn)
+ def compile_mycolumn(element, compiler, **kw):
+ return "[%s]" % element.name
+
+Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`,
+the base expression element for named column objects. The ``compiles``
+decorator registers itself with the ``MyColumn`` class so that it is invoked
+when the object is compiled to a string::
+
+ from sqlalchemy import select
+
+ s = select(MyColumn('x'), MyColumn('y'))
+ print(str(s))
+
+Produces::
+
+ SELECT [x], [y]
+
+Dialect-specific compilation rules
+==================================
+
+Compilers can also be made dialect-specific. The appropriate compiler will be
+invoked for the dialect in use::
+
+ from sqlalchemy.schema import DDLElement
+
+ class AlterColumn(DDLElement):
+ inherit_cache = False
+
+ def __init__(self, column, cmd):
+ self.column = column
+ self.cmd = cmd
+
+ @compiles(AlterColumn)
+ def visit_alter_column(element, compiler, **kw):
+ return "ALTER COLUMN %s ..." % element.column.name
+
+ @compiles(AlterColumn, 'postgresql')
+ def visit_alter_column(element, compiler, **kw):
+ return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name,
+ element.column.name)
+
+The second ``visit_alter_column`` will be invoked when any ``postgresql``
+dialect is used.
+
+.. _compilerext_compiling_subelements:
+
+Compiling sub-elements of a custom expression construct
+=======================================================
+
+The ``compiler`` argument is the
+:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object
+can be inspected for any information about the in-progress compilation,
+including ``compiler.dialect``, ``compiler.statement`` etc. The
+:class:`~sqlalchemy.sql.compiler.SQLCompiler` and
+:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()``
+method which can be used for compilation of embedded attributes::
+
+ from sqlalchemy.sql.expression import Executable, ClauseElement
+
+ class InsertFromSelect(Executable, ClauseElement):
+ inherit_cache = False
+
+ def __init__(self, table, select):
+ self.table = table
+ self.select = select
+
+ @compiles(InsertFromSelect)
+ def visit_insert_from_select(element, compiler, **kw):
+ return "INSERT INTO %s (%s)" % (
+ compiler.process(element.table, asfrom=True, **kw),
+ compiler.process(element.select, **kw)
+ )
+
+ insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5))
+ print(insert)
+
+Produces::
+
+ "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z
+ FROM mytable WHERE mytable.x > :x_1)"
+
+.. note::
+
+ The above ``InsertFromSelect`` construct is only an example; this
+ functionality is already available via the
+ :meth:`_expression.Insert.from_select` method.
+
+.. note::
+
+ The above ``InsertFromSelect`` construct probably wants to have "autocommit"
+ enabled. See :ref:`enabling_compiled_autocommit` for this step.
+
+Cross Compiling between SQL and DDL compilers
+---------------------------------------------
+
+SQL and DDL constructs are each compiled using different base compilers -
+``SQLCompiler`` and ``DDLCompiler``. A common need is to access the
+compilation rules of SQL expressions from within a DDL expression. The
+``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as
+below where we generate a CHECK constraint that embeds a SQL expression::
+
+ @compiles(MyConstraint)
+ def compile_my_constraint(constraint, ddlcompiler, **kw):
+ kw['literal_binds'] = True
+ return "CONSTRAINT %s CHECK (%s)" % (
+ constraint.name,
+ ddlcompiler.sql_compiler.process(
+ constraint.expression, **kw)
+ )
+
+Above, we add an additional flag to the process step as called by
+:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This
+indicates that any SQL expression which refers to a :class:`.BindParameter`
+object or other "literal" object such as those which refer to strings or
+integers should be rendered **in-place**, rather than being referred to as
+a bound parameter; when emitting DDL, bound parameters are typically not
+supported.
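+
+The ``MyConstraint`` referenced above is left undefined in this example; a
+minimal sketch of such a class might look like the following (the class
+structure here is an assumption for illustration, not part of the original
+example)::
+
+    from sqlalchemy.schema import Constraint
+
+    class MyConstraint(Constraint):
+        inherit_cache = False
+
+        def __init__(self, name, expression):
+            super(MyConstraint, self).__init__(name=name)
+            self.expression = expression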
+
+
+.. _enabling_compiled_autocommit:
+
+Enabling Autocommit on a Construct
+==================================
+
+Recall from the section :ref:`autocommit` that the :class:`_engine.Engine`,
+when
+asked to execute a construct in the absence of a user-defined transaction,
+detects if the given construct represents DML or DDL, that is, a data
+modification or data definition statement, which requires (or may require,
+in the case of DDL) that the transaction generated by the DBAPI be committed
+(recall that DBAPI always has a transaction going on regardless of what
+SQLAlchemy does). Checking for this is actually accomplished by checking for
+the "autocommit" execution option on the construct. When building a
+construct like an INSERT derivation, a new DDL type, or perhaps a stored
+procedure that alters data, the "autocommit" option needs to be set in order
+for the statement to function with "connectionless" execution
+(as described in :ref:`dbengine_implicit`).
+
+Currently a quick way to do this is to subclass :class:`.Executable`, then
+add the "autocommit" flag to the ``_execution_options`` dictionary (note this
+is a "frozen" dictionary which supplies a generative ``union()`` method)::
+
+ from sqlalchemy.sql.expression import Executable, ClauseElement
+
+ class MyInsertThing(Executable, ClauseElement):
+ _execution_options = \
+ Executable._execution_options.union({'autocommit': True})
+
+More succinctly, if the construct is truly similar to an INSERT, UPDATE, or
+DELETE, :class:`.UpdateBase` can be used, which is already a subclass
+of :class:`.Executable` and :class:`_expression.ClauseElement` and includes
+the ``autocommit`` flag::
+
+ from sqlalchemy.sql.expression import UpdateBase
+
+ class MyInsertThing(UpdateBase):
+ def __init__(self, ...):
+ ...
+
+DDL elements that subclass :class:`.DDLElement` already have the
+"autocommit" flag turned on.
+
+Changing the default compilation of existing constructs
+=======================================================
+
+The compiler extension applies just as well to the existing constructs. When
+overriding the compilation of a built in SQL construct, the @compiles
+decorator is invoked upon the appropriate class (be sure to use the class,
+i.e. ``Insert`` or ``Select``, instead of the creation function such
+as ``insert()`` or ``select()``).
+
+Within the new compilation function, to get at the "original" compilation
+routine, use the appropriate visit_XXX method - this is
+because compiler.process() will call upon the overriding routine and cause
+an endless loop. For example, to add a "prefix" to all insert statements::
+
+ from sqlalchemy.sql.expression import Insert
+
+ @compiles(Insert)
+ def prefix_inserts(insert, compiler, **kw):
+ return compiler.visit_insert(insert.prefix_with("some prefix"), **kw)
+
+The above compiler will prefix all INSERT statements with "some prefix" when
+compiled.
+
+.. _type_compilation_extension:
+
+Changing Compilation of Types
+=============================
+
+The ``compiler`` extension works for types, too, such as below where we
+implement the MS-SQL specific 'max' keyword for ``String``/``VARCHAR``::
+
+ @compiles(String, 'mssql')
+ @compiles(VARCHAR, 'mssql')
+ def compile_varchar(element, compiler, **kw):
+ if element.length == 'max':
+ return "VARCHAR('max')"
+ else:
+ return compiler.visit_VARCHAR(element, **kw)
+
+ foo = Table('foo', metadata,
+ Column('data', VARCHAR('max'))
+ )
+
+Subclassing Guidelines
+======================
+
+A big part of using the compiler extension is subclassing SQLAlchemy
+expression constructs. To make this easier, the expression and
+schema packages feature a set of "bases" intended for common tasks.
+A synopsis is as follows:
+
+* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root
+ expression class. Any SQL expression can be derived from this base, and is
+ probably the best choice for longer constructs such as specialized INSERT
+ statements.
+
+* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all
+ "column-like" elements. Anything that you'd place in the "columns" clause of
+ a SELECT statement (as well as order by and group by) can derive from this -
+ the object will automatically have Python "comparison" behavior.
+
+ :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a
+ ``type`` member which is the expression's return type. This can be
+ established at the instance level in the constructor, or at the class
+ level if it's generally constant::
+
+ class timestamp(ColumnElement):
+ type = TIMESTAMP()
+ inherit_cache = True
+
+* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a
+ ``ColumnElement`` and a "from clause" like object, and represents a SQL
+ function or stored procedure type of call. Since most databases support
+ statements along the lines of "SELECT FROM <some function>",
+ ``FunctionElement`` adds in the ability to be used in the FROM clause of a
+ ``select()`` construct::
+
+ from sqlalchemy.sql.expression import FunctionElement
+
+ class coalesce(FunctionElement):
+ name = 'coalesce'
+ inherit_cache = True
+
+ @compiles(coalesce)
+ def compile(element, compiler, **kw):
+ return "coalesce(%s)" % compiler.process(element.clauses, **kw)
+
+ @compiles(coalesce, 'oracle')
+ def compile(element, compiler, **kw):
+ if len(element.clauses) > 2:
+ raise TypeError("coalesce only supports two arguments on Oracle")
+ return "nvl(%s)" % compiler.process(element.clauses, **kw)
+
+* :class:`.DDLElement` - The root of all DDL expressions,
+ like CREATE TABLE, ALTER TABLE, etc. Compilation of :class:`.DDLElement`
+ subclasses is issued by a :class:`.DDLCompiler` instead of a
+ :class:`.SQLCompiler`. :class:`.DDLElement` can also be used as an event hook
+ in conjunction with events like :meth:`.DDLEvents.before_create` and
+ :meth:`.DDLEvents.after_create`, allowing the construct to be invoked
+ automatically during CREATE TABLE and DROP TABLE sequences.
+
+ .. seealso::
+
+ :ref:`metadata_ddl_toplevel` - contains examples of associating
+ :class:`.DDL` objects (which are themselves :class:`.DDLElement`
+ instances) with :class:`.DDLEvents` event hooks.
+
+* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which
+ should be used with any expression class that represents a "standalone"
+ SQL statement that can be passed directly to an ``execute()`` method. It
+ is already implicit within ``DDLElement`` and ``FunctionElement``.
+
+Most of the above constructs also respond to SQL statement caching. A
+subclassed construct will want to define the caching behavior for the object,
+which usually means setting the flag ``inherit_cache`` to the value of
+``False`` or ``True``. See the next section :ref:`compilerext_caching`
+for background.
+
+
+.. _compilerext_caching:
+
+Enabling Caching Support for Custom Constructs
+==============================================
+
+SQLAlchemy as of version 1.4 includes a
+:ref:`SQL compilation caching facility <sql_caching>` which will allow
+equivalent SQL constructs to cache their stringified form, along with other
+structural information used to fetch results from the statement.
+
+For reasons discussed at :ref:`caching_caveats`, the implementation of this
+caching system takes a conservative approach towards including custom SQL
+constructs and/or subclasses within the caching system. This means that
+any user-defined SQL constructs, including all the examples for this
+extension, will not participate in caching by default unless they positively
+assert that they are able to do so. The :attr:`.HasCacheKey.inherit_cache`
+attribute when set to ``True`` at the class level of a specific subclass
+will indicate that instances of this class may be safely cached, using the
+cache key generation scheme of the immediate superclass. This applies
+for example to the "synopsis" example indicated previously::
+
+ class MyColumn(ColumnClause):
+ inherit_cache = True
+
+ @compiles(MyColumn)
+ def compile_mycolumn(element, compiler, **kw):
+ return "[%s]" % element.name
+
+Above, the ``MyColumn`` class does not include any new state that
+affects its SQL compilation; the cache key of ``MyColumn`` instances will
+make use of that of the ``ColumnClause`` superclass, meaning it will take
+into account the class of the object (``MyColumn``), the string name and
+datatype of the object::
+
+ >>> MyColumn("some_name", String())._generate_cache_key()
+ CacheKey(
+ key=('0', <class '__main__.MyColumn'>,
+ 'name', 'some_name',
+ 'type', (<class 'sqlalchemy.sql.sqltypes.String'>,
+ ('length', None), ('collation', None))
+ ), bindparams=[])
+
+For objects that are likely to be **used liberally as components within many
+larger statements**, such as :class:`_schema.Column` subclasses and custom SQL
+datatypes, it's important that **caching be enabled as much as possible**, as
+this may otherwise negatively affect performance.
+
+An example of an object that **does** contain state which affects its SQL
+compilation is the one illustrated at :ref:`compilerext_compiling_subelements`;
+this is an "INSERT FROM SELECT" construct that combines together a
+:class:`_schema.Table` as well as a :class:`_sql.Select` construct, each of
+which independently affect the SQL string generation of the construct. For
+this class, the example illustrates that it simply does not participate in
+caching::
+
+ class InsertFromSelect(Executable, ClauseElement):
+ inherit_cache = False
+
+ def __init__(self, table, select):
+ self.table = table
+ self.select = select
+
+ @compiles(InsertFromSelect)
+ def visit_insert_from_select(element, compiler, **kw):
+ return "INSERT INTO %s (%s)" % (
+ compiler.process(element.table, asfrom=True, **kw),
+ compiler.process(element.select, **kw)
+ )
+
+While it is also possible that the above ``InsertFromSelect`` could be made to
+produce a cache key that is composed of that of the :class:`_schema.Table` and
+:class:`_sql.Select` components together, the API for this is not at the moment
+fully public. However, for an "INSERT FROM SELECT" construct, which is only
+used by itself for specific operations, caching is not as critical as in the
+previous example.
+
+For objects that are **used in relative isolation and are generally
+standalone**, such as custom :term:`DML` constructs like an "INSERT FROM
+SELECT", **caching is generally less critical** as the lack of caching for such
+a construct will have only localized implications for that specific operation.
+
+
+Further Examples
+================
+
+"UTC timestamp" function
+-------------------------
+
+A function that works like "CURRENT_TIMESTAMP" except that it applies the
+appropriate conversions so that the time is in UTC. Timestamps are best
+stored in relational databases as UTC, without time zones: UTC so that your
+database doesn't think time has gone backwards in the hour when daylight
+saving time ends, and without time zones because time zones are like
+character encodings - they're best applied only at the endpoints of an
+application (i.e. convert to UTC upon user input, re-apply the desired
+time zone upon display).
+
+For PostgreSQL and Microsoft SQL Server::
+
+ from sqlalchemy.sql import expression
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.types import DateTime
+
+ class utcnow(expression.FunctionElement):
+ type = DateTime()
+ inherit_cache = True
+
+ @compiles(utcnow, 'postgresql')
+ def pg_utcnow(element, compiler, **kw):
+ return "TIMEZONE('utc', CURRENT_TIMESTAMP)"
+
+ @compiles(utcnow, 'mssql')
+ def ms_utcnow(element, compiler, **kw):
+ return "GETUTCDATE()"
+
+Example usage::
+
+ from sqlalchemy import (
+ Table, Column, Integer, String, DateTime, MetaData
+ )
+ metadata = MetaData()
+ event = Table("event", metadata,
+ Column("id", Integer, primary_key=True),
+ Column("description", String(50), nullable=False),
+ Column("timestamp", DateTime, server_default=utcnow())
+ )
+
+"GREATEST" function
+-------------------
+
+The "GREATEST" function is given any number of arguments and returns the one
+that is of the highest value - its equivalent to Python's ``max``
+function. A SQL standard version versus a CASE based version which only
+accommodates two arguments::
+
+ from sqlalchemy.sql import expression, case
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.types import Numeric
+
+ class greatest(expression.FunctionElement):
+ type = Numeric()
+ name = 'greatest'
+ inherit_cache = True
+
+ @compiles(greatest)
+ def default_greatest(element, compiler, **kw):
+ return compiler.visit_function(element)
+
+ @compiles(greatest, 'sqlite')
+ @compiles(greatest, 'mssql')
+ @compiles(greatest, 'oracle')
+ def case_greatest(element, compiler, **kw):
+ arg1, arg2 = list(element.clauses)
+ return compiler.process(case([(arg1 > arg2, arg1)], else_=arg2), **kw)
+
+Example usage::
+
+ Session.query(Account).\
+ filter(
+ greatest(
+ Account.checking_balance,
+ Account.savings_balance) > 10000
+ )
+
+"false" expression
+------------------
+
+Render a "false" constant expression, rendering as "0" on platforms that
+don't have a "false" constant::
+
+ from sqlalchemy.sql import expression
+ from sqlalchemy.ext.compiler import compiles
+
+ class sql_false(expression.ColumnElement):
+ inherit_cache = True
+
+ @compiles(sql_false)
+ def default_false(element, compiler, **kw):
+ return "false"
+
+ @compiles(sql_false, 'mssql')
+ @compiles(sql_false, 'mysql')
+ @compiles(sql_false, 'oracle')
+ def int_false(element, compiler, **kw):
+ return "0"
+
+Example usage::
+
+ from sqlalchemy import select, union_all
+
+ exp = union_all(
+ select(users.c.name, sql_false().label("enrolled")),
+ select(customers.c.name, customers.c.enrolled)
+ )
+
+"""
+from .. import exc
+from .. import util
+from ..sql import sqltypes
+
+
+def compiles(class_, *specs):
+ """Register a function as a compiler for a
+ given :class:`_expression.ClauseElement` type."""
+
+ def decorate(fn):
+ # get an existing @compiles handler
+ existing = class_.__dict__.get("_compiler_dispatcher", None)
+
+ # get the original handler. All ClauseElement classes have one
+ # of these, but some TypeEngine classes will not.
+ existing_dispatch = getattr(class_, "_compiler_dispatch", None)
+
+ if not existing:
+ existing = _dispatcher()
+
+ if existing_dispatch:
+
+ def _wrap_existing_dispatch(element, compiler, **kw):
+ try:
+ return existing_dispatch(element, compiler, **kw)
+ except exc.UnsupportedCompilationError as uce:
+ util.raise_(
+ exc.UnsupportedCompilationError(
+ compiler,
+ type(element),
+ message="%s construct has no default "
+ "compilation handler." % type(element),
+ ),
+ from_=uce,
+ )
+
+ existing.specs["default"] = _wrap_existing_dispatch
+
+ # TODO: why is the lambda needed?
+ setattr(
+ class_,
+ "_compiler_dispatch",
+ lambda *arg, **kw: existing(*arg, **kw),
+ )
+ setattr(class_, "_compiler_dispatcher", existing)
+
+ if specs:
+ for s in specs:
+ existing.specs[s] = fn
+
+ else:
+ existing.specs["default"] = fn
+ return fn
+
+ return decorate
+
+
+def deregister(class_):
+ """Remove all custom compilers associated with a given
+ :class:`_expression.ClauseElement` type.
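+
+ E.g., to restore default compilation for the ``MyColumn`` example from
+ the module documentation (a minimal sketch)::
+
+     deregister(MyColumn)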
+
+ """
+
+ if hasattr(class_, "_compiler_dispatcher"):
+ class_._compiler_dispatch = class_._original_compiler_dispatch
+ del class_._compiler_dispatcher
+
+
+class _dispatcher(object):
+ def __init__(self):
+ self.specs = {}
+
+ def __call__(self, element, compiler, **kw):
+ # TODO: yes, this could also switch off of DBAPI in use.
+ fn = self.specs.get(compiler.dialect.name, None)
+ if not fn:
+ try:
+ fn = self.specs["default"]
+ except KeyError as ke:
+ util.raise_(
+ exc.UnsupportedCompilationError(
+ compiler,
+ type(element),
+ message="%s construct has no default "
+ "compilation handler." % type(element),
+ ),
+ replace_context=ke,
+ )
+
+ # if compilation includes add_to_result_map, collect add_to_result_map
+ # arguments from the user-defined callable, which are probably none
+ # because this is not public API. if it wasn't called, then call it
+ # ourselves.
+ arm = kw.get("add_to_result_map", None)
+ if arm:
+ arm_collection = []
+ kw["add_to_result_map"] = lambda *args: arm_collection.append(args)
+
+ expr = fn(element, compiler, **kw)
+
+ if arm:
+ if not arm_collection:
+ arm_collection.append(
+ (None, None, (element,), sqltypes.NULLTYPE)
+ )
+ for tup in arm_collection:
+ arm(*tup)
+ return expr
diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py
new file mode 100644
index 0000000..6215e35
--- /dev/null
+++ b/lib/sqlalchemy/ext/declarative/__init__.py
@@ -0,0 +1,64 @@
+# ext/declarative/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from .extensions import AbstractConcreteBase
+from .extensions import ConcreteBase
+from .extensions import DeferredReflection
+from .extensions import instrument_declarative
+from ... import util
+from ...orm.decl_api import as_declarative as _as_declarative
+from ...orm.decl_api import declarative_base as _declarative_base
+from ...orm.decl_api import DeclarativeMeta
+from ...orm.decl_api import declared_attr
+from ...orm.decl_api import has_inherited_table as _has_inherited_table
+from ...orm.decl_api import synonym_for as _synonym_for
+
+
+@util.moved_20(
+ "The ``declarative_base()`` function is now available as "
+ ":func:`sqlalchemy.orm.declarative_base`."
+)
+def declarative_base(*arg, **kw):
+ return _declarative_base(*arg, **kw)
+
+
+@util.moved_20(
+ "The ``as_declarative()`` function is now available as "
+ ":func:`sqlalchemy.orm.as_declarative`"
+)
+def as_declarative(*arg, **kw):
+ return _as_declarative(*arg, **kw)
+
+
+@util.moved_20(
+ "The ``has_inherited_table()`` function is now available as "
+ ":func:`sqlalchemy.orm.has_inherited_table`."
+)
+def has_inherited_table(*arg, **kw):
+ return _has_inherited_table(*arg, **kw)
+
+
+@util.moved_20(
+ "The ``synonym_for()`` function is now available as "
+ ":func:`sqlalchemy.orm.synonym_for`"
+)
+def synonym_for(*arg, **kw):
+ return _synonym_for(*arg, **kw)
+
+
+__all__ = [
+ "declarative_base",
+ "synonym_for",
+ "has_inherited_table",
+ "instrument_declarative",
+ "declared_attr",
+ "as_declarative",
+ "ConcreteBase",
+ "AbstractConcreteBase",
+ "DeclarativeMeta",
+ "DeferredReflection",
+]
diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py
new file mode 100644
index 0000000..7818841
--- /dev/null
+++ b/lib/sqlalchemy/ext/declarative/extensions.py
@@ -0,0 +1,463 @@
+# ext/declarative/extensions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+"""Public API functions and helpers for declarative."""
+
+
+from ... import inspection
+from ... import util
+from ...orm import exc as orm_exc
+from ...orm import registry
+from ...orm import relationships
+from ...orm.base import _mapper_or_none
+from ...orm.clsregistry import _resolver
+from ...orm.decl_base import _DeferredMapperConfig
+from ...orm.util import polymorphic_union
+from ...schema import Table
+from ...util import OrderedDict
+
+
+@util.deprecated(
+ "2.0",
+ "the instrument_declarative function is deprecated "
+ "and will be removed in SQLAlhcemy 2.0. Please use "
+ ":meth:`_orm.registry.map_declaratively",
+)
+def instrument_declarative(cls, cls_registry, metadata):
+ """Given a class, configure the class declaratively,
+ using the given registry, which can be any dictionary, and
+ MetaData object.
+
+ """
+ registry(metadata=metadata, class_registry=cls_registry).map_declaratively(
+ cls
+ )
+
+
+class ConcreteBase(object):
+ """A helper class for 'concrete' declarative mappings.
+
+ :class:`.ConcreteBase` will use the :func:`.polymorphic_union`
+ function automatically, against all tables mapped as a subclass
+ to this class. The function is called via the
+ ``__declare_last__()`` function, which is essentially
+ a hook for the :meth:`.after_configured` event.
+
+ :class:`.ConcreteBase` produces a mapped
+ table for the class itself. Compare to :class:`.AbstractConcreteBase`,
+ which does not.
+
+ Example::
+
+ from sqlalchemy.ext.declarative import ConcreteBase
+
+ class Employee(ConcreteBase, Base):
+ __tablename__ = 'employee'
+ employee_id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+ __mapper_args__ = {
+ 'polymorphic_identity':'employee',
+ 'concrete':True}
+
+ class Manager(Employee):
+ __tablename__ = 'manager'
+ employee_id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+ manager_data = Column(String(40))
+ __mapper_args__ = {
+ 'polymorphic_identity':'manager',
+ 'concrete':True}
+
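+ Queries against the base ``Employee`` class will then make use of the
+ polymorphic union, e.g. (a minimal sketch)::
+
+     session.query(Employee).all()
+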
+
+ The name of the discriminator column used by :func:`.polymorphic_union`
+ defaults to the name ``type``. To suit the use case of a mapping where an
+ actual column in a mapped table is already named ``type``, the
+ discriminator name can be configured by setting the
+ ``_concrete_discriminator_name`` attribute::
+
+ class Employee(ConcreteBase, Base):
+ _concrete_discriminator_name = '_concrete_discriminator'
+
+ .. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name``
+ attribute to :class:`_declarative.ConcreteBase` so that the
+ virtual discriminator column name can be customized.
+
+ .. versionchanged:: 1.4.2 The ``_concrete_discriminator_name`` attribute
+ need only be placed on the basemost class to take correct effect for
+ all subclasses. An explicit error message is now raised if the
+ mapped column names conflict with the discriminator name, whereas
+ in the 1.3.x series there would be some warnings and then a non-useful
+ query would be generated.
+
+ .. seealso::
+
+ :class:`.AbstractConcreteBase`
+
+ :ref:`concrete_inheritance`
+
+
+ """
+
+ @classmethod
+ def _create_polymorphic_union(cls, mappers, discriminator_name):
+ return polymorphic_union(
+ OrderedDict(
+ (mp.polymorphic_identity, mp.local_table) for mp in mappers
+ ),
+ discriminator_name,
+ "pjoin",
+ )
+
+ @classmethod
+ def __declare_first__(cls):
+ m = cls.__mapper__
+ if m.with_polymorphic:
+ return
+
+ discriminator_name = (
+ getattr(cls, "_concrete_discriminator_name", None) or "type"
+ )
+
+ mappers = list(m.self_and_descendants)
+ pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
+ m._set_with_polymorphic(("*", pjoin))
+ m._set_polymorphic_on(pjoin.c[discriminator_name])
+
+
+class AbstractConcreteBase(ConcreteBase):
+ """A helper class for 'concrete' declarative mappings.
+
+ :class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union`
+ function automatically, against all tables mapped as a subclass
+ to this class. The function is called via the
+ ``__declare_last__()`` function, which is essentially
+ a hook for the :meth:`.after_configured` event.
+
+ :class:`.AbstractConcreteBase` does produce a mapped class
+ for the base class, however it is not persisted to any table; it
+ is instead mapped directly to the "polymorphic" selectable directly
+ and is only used for selecting. Compare to :class:`.ConcreteBase`,
+ which does create a persisted table for the base class.
+
+ .. note::
+
+ The :class:`.AbstractConcreteBase` class does not intend to set up the
+ mapping for the base class until all the subclasses have been defined,
+ as it needs to create a mapping against a selectable that will include
+ all subclass tables. In order to achieve this, it waits for the
+ **mapper configuration event** to occur, at which point it scans
+ through all the configured subclasses and sets up a mapping that will
+ query against all subclasses at once.
+
+ While this event is normally invoked automatically, in the case of
+ :class:`.AbstractConcreteBase`, it may be necessary to invoke it
+ explicitly after **all** subclass mappings are defined, if the first
+ operation is to be a query against this base class. To do so, invoke
+ :func:`.configure_mappers` once all the desired classes have been
+ configured::
+
+ from sqlalchemy.orm import configure_mappers
+
+ configure_mappers()
+
+ .. seealso::
+
+ :func:`_orm.configure_mappers`
+
+
+ Example::
+
+ from sqlalchemy.ext.declarative import AbstractConcreteBase
+
+ class Employee(AbstractConcreteBase, Base):
+ pass
+
+ class Manager(Employee):
+ __tablename__ = 'manager'
+ employee_id = Column(Integer, primary_key=True)
+ name = Column(String(50))
+ manager_data = Column(String(40))
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'manager',
+ 'concrete':True}
+
+ configure_mappers()
+
+ The abstract base class is handled by declarative in a special way;
+ at class configuration time, it behaves like a declarative mixin
+ or an ``__abstract__`` base class. Once classes are configured
+ and mappings are produced, it then gets mapped itself, but
+ after all of its descendants. This is a unique system of mapping
+ not found elsewhere in SQLAlchemy.
+
+ Using this approach, we can specify columns and properties
+ that will take place on mapped subclasses, in the way that
+ we normally do as in :ref:`declarative_mixins`::
+
+ class Company(Base):
+ __tablename__ = 'company'
+ id = Column(Integer, primary_key=True)
+
+ class Employee(AbstractConcreteBase, Base):
+ employee_id = Column(Integer, primary_key=True)
+
+ @declared_attr
+ def company_id(cls):
+ return Column(ForeignKey('company.id'))
+
+ @declared_attr
+ def company(cls):
+ return relationship("Company")
+
+ class Manager(Employee):
+ __tablename__ = 'manager'
+
+ name = Column(String(50))
+ manager_data = Column(String(40))
+
+ __mapper_args__ = {
+ 'polymorphic_identity':'manager',
+ 'concrete':True}
+
+ configure_mappers()
+
+ When we make use of our mappings, however, both ``Manager`` and
+ ``Employee`` will have an independently usable ``.company`` attribute::
+
+ session.query(Employee).filter(Employee.company.has(id=5))
+
+ .. versionchanged:: 1.0.0 - The mechanics of :class:`.AbstractConcreteBase`
+ have been reworked to support relationships established directly
+ on the abstract base, without any special configurational steps.
+
+ .. seealso::
+
+ :class:`.ConcreteBase`
+
+ :ref:`concrete_inheritance`
+
+ """
+
+ __no_table__ = True
+
+ @classmethod
+ def __declare_first__(cls):
+ cls._sa_decl_prepare_nocascade()
+
+ @classmethod
+ def _sa_decl_prepare_nocascade(cls):
+ if getattr(cls, "__mapper__", None):
+ return
+
+ to_map = _DeferredMapperConfig.config_for_cls(cls)
+
+ # can't rely on 'self_and_descendants' here
+ # since technically an immediate subclass
+ # might not be mapped, but a subclass
+ # may be.
+ mappers = []
+ stack = list(cls.__subclasses__())
+ while stack:
+ klass = stack.pop()
+ stack.extend(klass.__subclasses__())
+ mn = _mapper_or_none(klass)
+ if mn is not None:
+ mappers.append(mn)
+
+ discriminator_name = (
+ getattr(cls, "_concrete_discriminator_name", None) or "type"
+ )
+ pjoin = cls._create_polymorphic_union(mappers, discriminator_name)
+
+ # For columns that were declared on the class, these
+ # are normally ignored with the "__no_table__" mapping,
+ # unless they have a different attribute key vs. col name
+ # and are in the properties argument.
+ # In that case, ensure we update the properties entry
+ # to the correct column from the pjoin target table.
+ declared_cols = set(to_map.declared_columns)
+ for k, v in list(to_map.properties.items()):
+ if v in declared_cols:
+ to_map.properties[k] = pjoin.c[v.key]
+
+ to_map.local_table = pjoin
+
+ m_args = to_map.mapper_args_fn or dict
+
+ def mapper_args():
+ args = m_args()
+ args["polymorphic_on"] = pjoin.c[discriminator_name]
+ return args
+
+ to_map.mapper_args_fn = mapper_args
+
+ m = to_map.map()
+
+ for scls in cls.__subclasses__():
+ sm = _mapper_or_none(scls)
+ if sm and sm.concrete and cls in scls.__bases__:
+ sm._set_concrete_base(m)
+
+ @classmethod
+ def _sa_raise_deferred_config(cls):
+ raise orm_exc.UnmappedClassError(
+ cls,
+ msg="Class %s is a subclass of AbstractConcreteBase and "
+ "has a mapping pending until all subclasses are defined. "
+ "Call the sqlalchemy.orm.configure_mappers() function after "
+ "all subclasses have been defined to "
+ "complete the mapping of this class."
+ % orm_exc._safe_cls_name(cls),
+ )
+
+
+class DeferredReflection(object):
+ """A helper class for construction of mappings based on
+ a deferred reflection step.
+
+ Normally, declarative can be used with reflection by
+ setting a :class:`_schema.Table` object using ``autoload_with=engine``
+ as the ``__table__`` attribute on a declarative class.
+ The caveat is that the :class:`_schema.Table` must be fully
+ reflected, or at the very least have a primary key column,
+ at the point at which a normal declarative mapping is
+ constructed, meaning the :class:`_engine.Engine` must be available
+ at class declaration time.
+
+ The :class:`.DeferredReflection` mixin moves the construction
+ of mappers to be at a later point, after a specific
+ method is called which first reflects all :class:`_schema.Table`
+ objects created so far. Classes can define it as follows::
+
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.declarative import DeferredReflection
+ Base = declarative_base()
+
+ class MyClass(DeferredReflection, Base):
+ __tablename__ = 'mytable'
+
+ Above, ``MyClass`` is not yet mapped. After a series of
+ classes have been defined in the above fashion, all tables
+ can be reflected and mappings created using
+ :meth:`.prepare`::
+
+ engine = create_engine("someengine://...")
+ DeferredReflection.prepare(engine)
+
+ The :class:`.DeferredReflection` mixin can be applied to individual
+ classes, used as the base for the declarative base itself,
+ or used in a custom abstract class. Using an abstract base
+ allows only a subset of classes to be prepared for a
+ particular prepare step, which is necessary for applications
+ that use more than one engine. For example, if an application
+ has two engines, you might use two bases, and prepare each
+ separately, e.g.::
+
+ class ReflectedOne(DeferredReflection, Base):
+ __abstract__ = True
+
+ class ReflectedTwo(DeferredReflection, Base):
+ __abstract__ = True
+
+ class MyClass(ReflectedOne):
+ __tablename__ = 'mytable'
+
+ class MyOtherClass(ReflectedOne):
+ __tablename__ = 'myothertable'
+
+ class YetAnotherClass(ReflectedTwo):
+ __tablename__ = 'yetanothertable'
+
+ # ... etc.
+
+ Above, the class hierarchies for ``ReflectedOne`` and
+ ``ReflectedTwo`` can be configured separately::
+
+ ReflectedOne.prepare(engine_one)
+ ReflectedTwo.prepare(engine_two)
+
+ .. seealso::
+
+ :ref:`orm_declarative_reflected_deferred_reflection` - in the
+ :ref:`orm_declarative_table_config_toplevel` section.
+
+ """
+
+ @classmethod
+ def prepare(cls, engine):
+ """Reflect all :class:`_schema.Table` objects for all current
+ :class:`.DeferredReflection` subclasses"""
+
+ to_map = _DeferredMapperConfig.classes_for_base(cls)
+
+ with inspection.inspect(engine)._inspection_context() as insp:
+ for thingy in to_map:
+ cls._sa_decl_prepare(thingy.local_table, insp)
+ thingy.map()
+ mapper = thingy.cls.__mapper__
+ metadata = mapper.class_.metadata
+ for rel in mapper._props.values():
+ if (
+ isinstance(rel, relationships.RelationshipProperty)
+ and rel.secondary is not None
+ ):
+ if isinstance(rel.secondary, Table):
+ cls._reflect_table(rel.secondary, insp)
+ elif isinstance(rel.secondary, str):
+
+ _, resolve_arg = _resolver(rel.parent.class_, rel)
+
+ rel.secondary = resolve_arg(rel.secondary)
+ rel.secondary._resolvers += (
+ cls._sa_deferred_table_resolver(
+ insp, metadata
+ ),
+ )
+
+ # controversy! do we resolve it here? or leave
+ # it deferred? I think doing it here is necessary
+ # so the connection does not leak.
+ rel.secondary = rel.secondary()
+
+ @classmethod
+ def _sa_deferred_table_resolver(cls, inspector, metadata):
+ def _resolve(key):
+ t1 = Table(key, metadata)
+ cls._reflect_table(t1, inspector)
+ return t1
+
+ return _resolve
+
+ @classmethod
+ def _sa_decl_prepare(cls, local_table, inspector):
+ # autoload Table, which is already
+ # present in the metadata. This
+ # will fill in db-loaded columns
+ # into the existing Table object.
+ if local_table is not None:
+ cls._reflect_table(local_table, inspector)
+
+ @classmethod
+ def _sa_raise_deferred_config(cls):
+ raise orm_exc.UnmappedClassError(
+ cls,
+ msg="Class %s is a subclass of DeferredReflection. "
+ "Mappings are not produced until the .prepare() "
+ "method is called on the class hierarchy."
+ % orm_exc._safe_cls_name(cls),
+ )
+
+ @classmethod
+ def _reflect_table(cls, table, inspector):
+ Table(
+ table.name,
+ table.metadata,
+ extend_existing=True,
+ autoload_replace=False,
+ autoload_with=inspector,
+ schema=table.schema,
+ )
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
new file mode 100644
index 0000000..bad076e
--- /dev/null
+++ b/lib/sqlalchemy/ext/horizontal_shard.py
@@ -0,0 +1,256 @@
+# ext/horizontal_shard.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Horizontal sharding support.
+
+Defines a rudimental 'horizontal sharding' system which allows a Session to
+distribute queries and persistence operations across multiple databases.
+
+For a usage example, see the :ref:`examples_sharding` example included in
+the source distribution.
+
+"""
+
+from .. import event
+from .. import exc
+from .. import inspect
+from .. import util
+from ..orm.query import Query
+from ..orm.session import Session
+
+__all__ = ["ShardedSession", "ShardedQuery"]
+
+
+class ShardedQuery(Query):
+ def __init__(self, *args, **kwargs):
+ super(ShardedQuery, self).__init__(*args, **kwargs)
+ self.id_chooser = self.session.id_chooser
+ self.query_chooser = self.session.query_chooser
+ self.execute_chooser = self.session.execute_chooser
+ self._shard_id = None
+
+ def set_shard(self, shard_id):
+ """Return a new query, limited to a single shard ID.
+
+ All subsequent operations with the returned query will
+ be against the single shard regardless of other state.
+
+ The shard_id can be passed for a 2.0 style execution to the
+ bind_arguments dictionary of :meth:`.Session.execute`::
+
+ results = session.execute(
+ stmt,
+ bind_arguments={"shard_id": "my_shard"}
+ )
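+
+ With the legacy :class:`_query.Query` API, this method is called
+ directly (a minimal sketch, assuming a mapped ``User`` class)::
+
+     query = session.query(User).set_shard("shard_1")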
+
+ """
+ return self.execution_options(_sa_shard_id=shard_id)
+
+
+class ShardedSession(Session):
+ def __init__(
+ self,
+ shard_chooser,
+ id_chooser,
+ execute_chooser=None,
+ shards=None,
+ query_cls=ShardedQuery,
+ **kwargs
+ ):
+ """Construct a ShardedSession.
+
+ :param shard_chooser: A callable which, passed a Mapper, a mapped
+ instance, and possibly a SQL clause, returns a shard ID. This id
+ may be based on the attributes present within the object, or on
+ some round-robin scheme. If the scheme is based on a selection, it
+ should set whatever state is needed on the instance to mark it as
+ participating in that shard in the future.
+
+ :param id_chooser: A callable, passed a query and a tuple of identity
+ values, which should return a list of shard ids where the ID might
+ reside. The databases will be queried in the order of this listing.
+
+ :param execute_chooser: For a given :class:`.ORMExecuteState`,
+ returns the list of shard_ids
+ where the query should be issued. Results from all shards returned
+ will be combined together into a single listing.
+
+ .. versionchanged:: 1.4 The ``execute_chooser`` parameter
+ supersedes the ``query_chooser`` parameter.
+
+ :param shards: A dictionary of string shard names
+ to :class:`~sqlalchemy.engine.Engine` objects.
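+
+ A minimal construction sketch (engine URLs and shard names are
+ hypothetical)::
+
+     from sqlalchemy import create_engine
+
+     shards = {
+         "shard_1": create_engine("sqlite://"),
+         "shard_2": create_engine("sqlite://"),
+     }
+
+     session = ShardedSession(
+         shard_chooser=lambda mapper, instance, clause=None: "shard_1",
+         id_chooser=lambda query, ident: ["shard_1", "shard_2"],
+         execute_chooser=lambda context: ["shard_1", "shard_2"],
+         shards=shards,
+     )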
+
+ """
+ query_chooser = kwargs.pop("query_chooser", None)
+ super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs)
+
+ event.listen(
+ self, "do_orm_execute", execute_and_instances, retval=True
+ )
+ self.shard_chooser = shard_chooser
+ self.id_chooser = id_chooser
+
+ if query_chooser:
+ util.warn_deprecated(
+ "The ``query_choser`` parameter is deprecated; "
+ "please use ``execute_chooser``.",
+ "1.4",
+ )
+ if execute_chooser:
+ raise exc.ArgumentError(
+ "Can't pass query_chooser and execute_chooser "
+ "at the same time."
+ )
+
+ def execute_chooser(orm_context):
+ return query_chooser(orm_context.statement)
+
+ self.execute_chooser = execute_chooser
+ else:
+ self.execute_chooser = execute_chooser
+ self.query_chooser = query_chooser
+ self.__binds = {}
+ if shards is not None:
+ for k in shards:
+ self.bind_shard(k, shards[k])
+
+ def _identity_lookup(
+ self,
+ mapper,
+ primary_key_identity,
+ identity_token=None,
+ lazy_loaded_from=None,
+ **kw
+ ):
+ """override the default :meth:`.Session._identity_lookup` method so
+ that we search for a given non-token primary key identity across all
+ possible identity tokens (e.g. shard ids).
+
+ .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from
+ the :class:`_query.Query` object to the :class:`.Session`.
+
+ """
+
+ if identity_token is not None:
+ return super(ShardedSession, self)._identity_lookup(
+ mapper,
+ primary_key_identity,
+ identity_token=identity_token,
+ **kw
+ )
+ else:
+ q = self.query(mapper)
+ if lazy_loaded_from:
+ q = q._set_lazyload_from(lazy_loaded_from)
+ for shard_id in self.id_chooser(q, primary_key_identity):
+ obj = super(ShardedSession, self)._identity_lookup(
+ mapper,
+ primary_key_identity,
+ identity_token=shard_id,
+ lazy_loaded_from=lazy_loaded_from,
+ **kw
+ )
+ if obj is not None:
+ return obj
+
+ return None
+
+ def _choose_shard_and_assign(self, mapper, instance, **kw):
+ if instance is not None:
+ state = inspect(instance)
+ if state.key:
+ token = state.key[2]
+ assert token is not None
+ return token
+ elif state.identity_token:
+ return state.identity_token
+
+ shard_id = self.shard_chooser(mapper, instance, **kw)
+ if instance is not None:
+ state.identity_token = shard_id
+ return shard_id
+
+ def connection_callable(
+ self, mapper=None, instance=None, shard_id=None, **kwargs
+ ):
+ """Provide a :class:`_engine.Connection` to use in the unit of work
+ flush process.
+
+ """
+
+ if shard_id is None:
+ shard_id = self._choose_shard_and_assign(mapper, instance)
+
+ if self.in_transaction():
+ return self.get_transaction().connection(mapper, shard_id=shard_id)
+ else:
+ return self.get_bind(
+ mapper, shard_id=shard_id, instance=instance
+ ).connect(**kwargs)
+
+ def get_bind(
+ self, mapper=None, shard_id=None, instance=None, clause=None, **kw
+ ):
+ if shard_id is None:
+ shard_id = self._choose_shard_and_assign(
+ mapper, instance, clause=clause
+ )
+ return self.__binds[shard_id]
+
+ def bind_shard(self, shard_id, bind):
+ self.__binds[shard_id] = bind
+
+
+def execute_and_instances(orm_context):
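+ # Handler for the "do_orm_execute" event established by ShardedSession:
+ # route the statement to a single shard when a shard id is present in the
+ # load/update options, execution options or bind arguments; otherwise run
+ # it against every shard returned by execute_chooser and merge the results.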
+ if orm_context.is_select:
+ load_options = active_options = orm_context.load_options
+ update_options = None
+
+ elif orm_context.is_update or orm_context.is_delete:
+ load_options = None
+ update_options = active_options = orm_context.update_delete_options
+ else:
+ load_options = update_options = active_options = None
+
+ session = orm_context.session
+
+ def iter_for_shard(shard_id, load_options, update_options):
+ execution_options = dict(orm_context.local_execution_options)
+
+ bind_arguments = dict(orm_context.bind_arguments)
+ bind_arguments["shard_id"] = shard_id
+
+ if orm_context.is_select:
+ load_options += {"_refresh_identity_token": shard_id}
+ execution_options["_sa_orm_load_options"] = load_options
+ elif orm_context.is_update or orm_context.is_delete:
+ update_options += {"_refresh_identity_token": shard_id}
+ execution_options["_sa_orm_update_options"] = update_options
+
+ return orm_context.invoke_statement(
+ bind_arguments=bind_arguments, execution_options=execution_options
+ )
+
+ if active_options and active_options._refresh_identity_token is not None:
+ shard_id = active_options._refresh_identity_token
+ elif "_sa_shard_id" in orm_context.execution_options:
+ shard_id = orm_context.execution_options["_sa_shard_id"]
+ elif "shard_id" in orm_context.bind_arguments:
+ shard_id = orm_context.bind_arguments["shard_id"]
+ else:
+ shard_id = None
+
+ if shard_id is not None:
+ return iter_for_shard(shard_id, load_options, update_options)
+ else:
+ partial = []
+ for shard_id in session.execute_chooser(orm_context):
+ result_ = iter_for_shard(shard_id, load_options, update_options)
+ partial.append(result_)
+
+ return partial[0].merge(*partial[1:])
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
new file mode 100644
index 0000000..cc0aca6
--- /dev/null
+++ b/lib/sqlalchemy/ext/hybrid.py
@@ -0,0 +1,1206 @@
+# ext/hybrid.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Define attributes on ORM-mapped classes that have "hybrid" behavior.
+
+"hybrid" means the attribute has distinct behaviors defined at the
+class level and at the instance level.
+
+The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of
+method decorator; it is around 50 lines of code and has almost no
+dependencies on the rest of SQLAlchemy. It can, in theory, work with
+any descriptor-based expression system.
+
+Consider a mapping ``Interval``, representing integer ``start`` and ``end``
+values. We can define higher level functions on mapped classes that produce SQL
+expressions at the class level, and Python expression evaluation at the
+instance level. Below, each function decorated with :class:`.hybrid_method` or
+:class:`.hybrid_property` may receive ``self`` as an instance of the class, or
+as the class itself::
+
+ from sqlalchemy import Column, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.orm import Session, aliased
+ from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
+
+ Base = declarative_base()
+
+ class Interval(Base):
+ __tablename__ = 'interval'
+
+ id = Column(Integer, primary_key=True)
+ start = Column(Integer, nullable=False)
+ end = Column(Integer, nullable=False)
+
+ def __init__(self, start, end):
+ self.start = start
+ self.end = end
+
+ @hybrid_property
+ def length(self):
+ return self.end - self.start
+
+ @hybrid_method
+ def contains(self, point):
+ return (self.start <= point) & (point <= self.end)
+
+ @hybrid_method
+ def intersects(self, other):
+ return self.contains(other.start) | self.contains(other.end)
+
+Above, the ``length`` property returns the difference between the
+``end`` and ``start`` attributes. With an instance of ``Interval``,
+this subtraction occurs in Python, using normal Python descriptor
+mechanics::
+
+ >>> i1 = Interval(5, 10)
+ >>> i1.length
+ 5
+
+When dealing with the ``Interval`` class itself, the :class:`.hybrid_property`
+descriptor evaluates the function body given the ``Interval`` class as
+the argument, which when evaluated with SQLAlchemy expression mechanics
+(here using the :attr:`.QueryableAttribute.expression` accessor)
+returns a new SQL expression::
+
+ >>> print(Interval.length.expression)
+ interval."end" - interval.start
+
+ >>> print(Session().query(Interval).filter(Interval.length > 10))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE interval."end" - interval.start > :param_1
+
+ORM methods such as :meth:`_query.Query.filter_by`
+generally use ``getattr()`` to
+locate attributes, so they can also be used with hybrid attributes::
+
+ >>> print(Session().query(Interval).filter_by(length=5))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE interval."end" - interval.start = :param_1
+
+The ``Interval`` class example also illustrates two methods,
+``contains()`` and ``intersects()``, decorated with
+:class:`.hybrid_method`. This decorator applies the same idea to
+methods that :class:`.hybrid_property` applies to attributes. The
+methods return boolean values, and take advantage of the Python ``|``
+and ``&`` bitwise operators to produce equivalent instance-level and
+SQL expression-level boolean behavior::
+
+ >>> i1.contains(6)
+ True
+ >>> i1.contains(15)
+ False
+ >>> i1.intersects(Interval(7, 18))
+ True
+ >>> i1.intersects(Interval(25, 29))
+ False
+
+ >>> print(Session().query(Interval).filter(Interval.contains(15)))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE interval.start <= :start_1 AND interval."end" > :end_1
+
+ >>> ia = aliased(Interval)
+ >>> print(Session().query(Interval, ia).filter(Interval.intersects(ia)))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end, interval_1.id AS interval_1_id,
+ interval_1.start AS interval_1_start, interval_1."end" AS interval_1_end
+ FROM interval, interval AS interval_1
+ WHERE interval.start <= interval_1.start
+ AND interval."end" > interval_1.start
+ OR interval.start <= interval_1."end"
+ AND interval."end" > interval_1."end"
+
+.. _hybrid_distinct_expression:
+
+Defining Expression Behavior Distinct from Attribute Behavior
+--------------------------------------------------------------
+
+Our usage of the ``&`` and ``|`` bitwise operators above was
+fortunate, considering our functions operated on two boolean values to
+return a new one. In many cases, the construction of an in-Python
+function and a SQLAlchemy SQL expression have enough differences that
+two separate Python expressions should be defined. The
+:mod:`~sqlalchemy.ext.hybrid` decorators define the
+:meth:`.hybrid_property.expression` modifier for this purpose. As an
+example we'll define the radius of the interval, which requires the
+usage of the absolute value function::
+
+ from sqlalchemy import func
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def radius(self):
+ return abs(self.length) / 2
+
+ @radius.expression
+ def radius(cls):
+ return func.abs(cls.length) / 2
+
+Above, the Python function ``abs()`` is used for instance-level
+operations, while the SQL function ``ABS()`` is used via the :data:`.func`
+object for class-level expressions::
+
+ >>> i1.radius
+ 2
+
+ >>> print(Session().query(Interval).filter(Interval.radius > 5))
+ SELECT interval.id AS interval_id, interval.start AS interval_start,
+ interval."end" AS interval_end
+ FROM interval
+ WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1
+
+.. note:: When defining an expression for a hybrid property or method, the
+ expression method **must** retain the name of the original hybrid, else
+ the new hybrid with the additional state will be attached to the class
+ with the non-matching name. To use the example above::
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def radius(self):
+ return abs(self.length) / 2
+
+ # WRONG - the non-matching name will cause this function to be
+ # ignored
+ @radius.expression
+ def radius_expression(cls):
+ return func.abs(cls.length) / 2
+
+ This is also true for other mutator methods, such as
+ :meth:`.hybrid_property.update_expression`. This is the same behavior
+ as that of the ``@property`` construct that is part of standard Python.
+
+Defining Setters
+----------------
+
+Hybrid properties can also define setter methods. If we wanted
+``length`` above, when set, to modify the endpoint value::
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def length(self):
+ return self.end - self.start
+
+ @length.setter
+ def length(self, value):
+ self.end = self.start + value
+
+The ``length(self, value)`` method is now called upon set::
+
+ >>> i1 = Interval(5, 10)
+ >>> i1.length
+ 5
+ >>> i1.length = 12
+ >>> i1.end
+ 17
+
+.. _hybrid_bulk_update:
+
+Allowing Bulk ORM Update
+------------------------
+
+A hybrid can define a custom "UPDATE" handler for use with the
+:meth:`_query.Query.update` method, allowing the hybrid to be used in the
+SET clause of the update.
+
+Normally, when using a hybrid with :meth:`_query.Query.update`, the SQL
+expression is used as the column that's the target of the SET. If our
+``Interval`` class had a hybrid ``start_point`` that linked to
+``Interval.start``, this could be substituted directly::
+
+ session.query(Interval).update({Interval.start_point: 10})
+
+However, when using a composite hybrid like ``Interval.length``, this
+hybrid represents more than one column. We can set up a handler that
+accommodates a value passed to :meth:`_query.Query.update`, using the
+:meth:`.hybrid_property.update_expression` decorator. A handler that works
+similarly to our setter would be::
+
+ class Interval(object):
+ # ...
+
+ @hybrid_property
+ def length(self):
+ return self.end - self.start
+
+ @length.setter
+ def length(self, value):
+ self.end = self.start + value
+
+ @length.update_expression
+ def length(cls, value):
+ return [
+ (cls.end, cls.start + value)
+ ]
+
+Above, if we use ``Interval.length`` in an UPDATE expression as::
+
+ session.query(Interval).update(
+ {Interval.length: 25}, synchronize_session='fetch')
+
+We'll get an UPDATE statement along the lines of::
+
+ UPDATE interval SET end=start + :value
+
+In some cases, the default "evaluate" strategy can't perform the SET
+expression in Python; while the addition operator we're using above
+is supported, for more complex SET expressions it will usually be necessary
+to use either the "fetch" or False synchronization strategy as illustrated
+above.
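+
+For example, a minimal sketch that disables the synchronization step
+entirely, an option when the SET expression is too complex for the
+"evaluate" strategy (the value given is illustrative)::
+
+    session.query(Interval).update(
+        {Interval.length: 25}, synchronize_session=False)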
+
+.. note:: For ORM bulk updates to work with hybrids, the function name
+ of the hybrid must match that of how it is accessed. Something
+ like this wouldn't work::
+
+ class Interval(object):
+ # ...
+
+ def _get(self):
+ return self.end - self.start
+
+ def _set(self, value):
+ self.end = self.start + value
+
+ def _update_expr(cls, value):
+ return [
+ (cls.end, cls.start + value)
+ ]
+
+ length = hybrid_property(
+ fget=_get, fset=_set, update_expr=_update_expr
+ )
+
+ The Python descriptor protocol does not provide any reliable way for
+ a descriptor to know what attribute name it was accessed as, and
+ the UPDATE scheme currently relies upon being able to access the
+ attribute from an instance by name in order to perform the instance
+ synchronization step.
+
+.. versionadded:: 1.2 added support for bulk updates to hybrid properties.
+
+Working with Relationships
+--------------------------
+
+There's no essential difference when creating hybrids that work with
+related objects as opposed to column-based data. The need for distinct
+expressions tends to be greater. The two variants we'll illustrate
+are the "join-dependent" hybrid, and the "correlated subquery" hybrid.
+
+Join-Dependent Relationship Hybrid
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Consider the following declarative
+mapping which relates a ``User`` to a ``SavingsAccount``::
+
+ from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
+ from sqlalchemy.orm import relationship
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.hybrid import hybrid_property
+
+ Base = declarative_base()
+
+ class SavingsAccount(Base):
+ __tablename__ = 'account'
+ id = Column(Integer, primary_key=True)
+ user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
+ balance = Column(Numeric(15, 5))
+
+ class User(Base):
+ __tablename__ = 'user'
+ id = Column(Integer, primary_key=True)
+ name = Column(String(100), nullable=False)
+
+ accounts = relationship("SavingsAccount", backref="owner")
+
+ @hybrid_property
+ def balance(self):
+ if self.accounts:
+ return self.accounts[0].balance
+ else:
+ return None
+
+ @balance.setter
+ def balance(self, value):
+ if not self.accounts:
+                account = SavingsAccount(owner=self)
+ else:
+ account = self.accounts[0]
+ account.balance = value
+
+ @balance.expression
+ def balance(cls):
+ return SavingsAccount.balance
+
+The above hybrid property ``balance`` works with the first
+``SavingsAccount`` entry in the list of accounts for this user. The
+in-Python getter/setter methods can treat ``accounts`` as a Python
+list available on ``self``.
+
+However, at the expression level, it's expected that the ``User`` class will
+be used in an appropriate context such that an appropriate join to
+``SavingsAccount`` will be present::
+
+ >>> print(Session().query(User, User.balance).
+ ... join(User.accounts).filter(User.balance > 5000))
+ SELECT "user".id AS user_id, "user".name AS user_name,
+ account.balance AS account_balance
+ FROM "user" JOIN account ON "user".id = account.user_id
+ WHERE account.balance > :balance_1
+
+Note however, that while the instance level accessors need to worry
+about whether ``self.accounts`` is even present, this issue expresses
+itself differently at the SQL expression level, where we basically
+would use an outer join::
+
+ >>> from sqlalchemy import or_
+    >>> print(Session().query(User, User.balance).outerjoin(User.accounts).
+ ... filter(or_(User.balance < 5000, User.balance == None)))
+ SELECT "user".id AS user_id, "user".name AS user_name,
+ account.balance AS account_balance
+ FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id
+ WHERE account.balance < :balance_1 OR account.balance IS NULL
+
+Correlated Subquery Relationship Hybrid
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+We can, of course, forego being dependent on the enclosing query's usage
+of joins in favor of the correlated subquery, which can portably be packed
+into a single column expression. A correlated subquery is more portable, but
+often performs more poorly at the SQL level. Using the same technique
+illustrated at :ref:`mapper_column_property_sql_expressions`,
+we can adjust our ``SavingsAccount`` example to aggregate the balances for
+*all* accounts, and use a correlated subquery for the column expression::
+
+ from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
+ from sqlalchemy.orm import relationship
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.hybrid import hybrid_property
+ from sqlalchemy import select, func
+
+ Base = declarative_base()
+
+ class SavingsAccount(Base):
+ __tablename__ = 'account'
+ id = Column(Integer, primary_key=True)
+ user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
+ balance = Column(Numeric(15, 5))
+
+ class User(Base):
+ __tablename__ = 'user'
+ id = Column(Integer, primary_key=True)
+ name = Column(String(100), nullable=False)
+
+ accounts = relationship("SavingsAccount", backref="owner")
+
+ @hybrid_property
+ def balance(self):
+ return sum(acc.balance for acc in self.accounts)
+
+ @balance.expression
+ def balance(cls):
+ return select(func.sum(SavingsAccount.balance)).\
+ where(SavingsAccount.user_id==cls.id).\
+ label('total_balance')
+
+The above recipe will give us the ``balance`` column which renders
+a correlated SELECT::
+
+    >>> print(Session().query(User).filter(User.balance > 400))
+ SELECT "user".id AS user_id, "user".name AS user_name
+ FROM "user"
+ WHERE (SELECT sum(account.balance) AS sum_1
+ FROM account
+ WHERE account.user_id = "user".id) > :param_1
+
+.. _hybrid_custom_comparators:
+
+Building Custom Comparators
+---------------------------
+
+The hybrid property also includes a helper that allows construction of
+custom comparators. A comparator object allows one to customize the
+behavior of each SQLAlchemy expression operator individually. They
+are useful when creating custom types that have some highly
+idiosyncratic behavior on the SQL side.
+
+.. note:: The :meth:`.hybrid_property.comparator` decorator introduced
+ in this section **replaces** the use of the
+ :meth:`.hybrid_property.expression` decorator.
+ They cannot be used together.
+
+The example class below allows case-insensitive comparisons on the attribute
+named ``word_insensitive``::
+
+ from sqlalchemy.ext.hybrid import Comparator, hybrid_property
+ from sqlalchemy import func, Column, Integer, String
+ from sqlalchemy.orm import Session
+ from sqlalchemy.ext.declarative import declarative_base
+
+ Base = declarative_base()
+
+ class CaseInsensitiveComparator(Comparator):
+ def __eq__(self, other):
+ return func.lower(self.__clause_element__()) == func.lower(other)
+
+ class SearchWord(Base):
+ __tablename__ = 'searchword'
+ id = Column(Integer, primary_key=True)
+ word = Column(String(255), nullable=False)
+
+ @hybrid_property
+ def word_insensitive(self):
+ return self.word.lower()
+
+ @word_insensitive.comparator
+ def word_insensitive(cls):
+ return CaseInsensitiveComparator(cls.word)
+
+Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()``
+SQL function to both sides::
+
+ >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks"))
+ SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
+ FROM searchword
+ WHERE lower(searchword.word) = lower(:lower_1)
+
+The ``CaseInsensitiveComparator`` above implements part of the
+:class:`.ColumnOperators` interface. A "coercion" operation like
+lowercasing can be applied to all comparison operations (i.e. ``eq``,
+``lt``, ``gt``, etc.) using :meth:`.Operators.operate`::
+
+ class CaseInsensitiveComparator(Comparator):
+ def operate(self, op, other, **kwargs):
+ return op(
+ func.lower(self.__clause_element__()),
+ func.lower(other),
+ **kwargs,
+ )
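+
+With this version, comparison operators other than ``==`` are routed through
+``operate()`` as well; the query below is a sketch of the shape such an
+expression would take::
+
+    >>> print(Session().query(SearchWord).filter(SearchWord.word_insensitive > "Trucks"))
+    SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
+    FROM searchword
+    WHERE lower(searchword.word) > lower(:lower_1)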
+
+.. _hybrid_reuse_subclass:
+
+Reusing Hybrid Properties across Subclasses
+-------------------------------------------
+
+A hybrid can be referred to from a superclass, allowing modifying
+decorators like :meth:`.hybrid_property.getter` and
+:meth:`.hybrid_property.setter` to be used to redefine those methods on a
+subclass. This is similar to how the standard Python ``@property`` object
+works::
+
+ class FirstNameOnly(Base):
+ # ...
+
+ first_name = Column(String)
+
+ @hybrid_property
+ def name(self):
+ return self.first_name
+
+ @name.setter
+ def name(self, value):
+ self.first_name = value
+
+ class FirstNameLastName(FirstNameOnly):
+ # ...
+
+ last_name = Column(String)
+
+ @FirstNameOnly.name.getter
+ def name(self):
+ return self.first_name + ' ' + self.last_name
+
+ @name.setter
+ def name(self, value):
+ self.first_name, self.last_name = value.split(' ', 1)
+
+Above, the ``FirstNameLastName`` class refers to the hybrid from
+``FirstNameOnly.name`` to repurpose its getter and setter for the subclass.
+
+When overriding :meth:`.hybrid_property.expression` and
+:meth:`.hybrid_property.comparator` alone as the first reference to the
+superclass, these names conflict with the same-named accessors on the
+class-level :class:`.QueryableAttribute` object returned at the class
+level. To override these methods when referring directly to the parent
+class descriptor, add the special qualifier
+:attr:`.hybrid_property.overrides`, which will de-reference the
+instrumented attribute back to the hybrid object::
+
+ class FirstNameLastName(FirstNameOnly):
+ # ...
+
+ last_name = Column(String)
+
+ @FirstNameOnly.name.overrides.expression
+ def name(cls):
+ return func.concat(cls.first_name, ' ', cls.last_name)
+
+.. versionadded:: 1.2 Added :meth:`.hybrid_property.getter` as well as the
+ ability to redefine accessors per-subclass.
+
+
+Hybrid Value Objects
+--------------------
+
+Note in our previous example, if we were to compare the ``word_insensitive``
+attribute of a ``SearchWord`` instance to a plain Python string, the plain
+Python string would not be coerced to lower case - the
+``CaseInsensitiveComparator`` we built, being returned by
+``@word_insensitive.comparator``, only applies to the SQL side.
+
+A more comprehensive form of the custom comparator is to construct a *Hybrid
+Value Object*. This technique applies the target value or expression to a value
+object which is then returned by the accessor in all cases. The value object
+allows control of all operations upon the value as well as how compared values
+are treated, both on the SQL expression side as well as the Python value side.
+Replacing the previous ``CaseInsensitiveComparator`` class with a new
+``CaseInsensitiveWord`` class::
+
+ class CaseInsensitiveWord(Comparator):
+ "Hybrid value representing a lower case representation of a word."
+
+ def __init__(self, word):
+            if isinstance(word, str):
+ self.word = word.lower()
+ elif isinstance(word, CaseInsensitiveWord):
+ self.word = word.word
+ else:
+ self.word = func.lower(word)
+
+ def operate(self, op, other, **kwargs):
+ if not isinstance(other, CaseInsensitiveWord):
+ other = CaseInsensitiveWord(other)
+ return op(self.word, other.word, **kwargs)
+
+ def __clause_element__(self):
+ return self.word
+
+ def __str__(self):
+ return self.word
+
+ key = 'word'
+ "Label to apply to Query tuple results"
+
+Above, the ``CaseInsensitiveWord`` object represents ``self.word``, which may
+be a SQL function, or may be a Python native. By overriding ``operate()`` and
+``__clause_element__()`` to work in terms of ``self.word``, all comparison
+operations will work against the "converted" form of ``word``, whether it be
+SQL side or Python side. Our ``SearchWord`` class can now deliver the
+``CaseInsensitiveWord`` object unconditionally from a single hybrid call::
+
+ class SearchWord(Base):
+ __tablename__ = 'searchword'
+ id = Column(Integer, primary_key=True)
+ word = Column(String(255), nullable=False)
+
+ @hybrid_property
+ def word_insensitive(self):
+ return CaseInsensitiveWord(self.word)
+
+The ``word_insensitive`` attribute now has case-insensitive comparison behavior
+universally, including SQL expression vs. Python expression (note the Python
+value is converted to lower case on the Python side here)::
+
+ >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks"))
+ SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
+ FROM searchword
+ WHERE lower(searchword.word) = :lower_1
+
+SQL expression versus SQL expression::
+
+ >>> sw1 = aliased(SearchWord)
+ >>> sw2 = aliased(SearchWord)
+ >>> print(Session().query(
+ ... sw1.word_insensitive,
+ ... sw2.word_insensitive).\
+ ... filter(
+ ... sw1.word_insensitive > sw2.word_insensitive
+ ... ))
+ SELECT lower(searchword_1.word) AS lower_1,
+ lower(searchword_2.word) AS lower_2
+ FROM searchword AS searchword_1, searchword AS searchword_2
+ WHERE lower(searchword_1.word) > lower(searchword_2.word)
+
+Python only expression::
+
+ >>> ws1 = SearchWord(word="SomeWord")
+ >>> ws1.word_insensitive == "sOmEwOrD"
+ True
+ >>> ws1.word_insensitive == "XOmEwOrX"
+ False
+ >>> print(ws1.word_insensitive)
+ someword
+
+The Hybrid Value pattern is very useful for any kind of value that may have
+multiple representations, such as timestamps, time deltas, units of
+measurement, currencies and encrypted passwords.
+
+.. seealso::
+
+ `Hybrids and Value Agnostic Types
+ <https://techspot.zzzeek.org/2011/10/21/hybrids-and-value-agnostic-types/>`_
+ - on the techspot.zzzeek.org blog
+
+ `Value Agnostic Types, Part II
+ <https://techspot.zzzeek.org/2011/10/29/value-agnostic-types-part-ii/>`_ -
+ on the techspot.zzzeek.org blog
+
+.. _hybrid_transformers:
+
+Building Transformers
+----------------------
+
+A *transformer* is an object which can receive a :class:`_query.Query`
+object and
+return a new one. The :class:`_query.Query` object includes a method
+:meth:`.with_transformation` that returns a new :class:`_query.Query`
+transformed by
+the given function.
+
+We can combine this with the :class:`.Comparator` class to produce one type
+of recipe which can both set up the FROM clause of a query as well as assign
+filtering criterion.
+
+Consider a mapped class ``Node``, which assembles itself, using the
+adjacency list pattern, into a hierarchical tree::
+
+ from sqlalchemy import Column, Integer, ForeignKey
+ from sqlalchemy.orm import relationship
+ from sqlalchemy.ext.declarative import declarative_base
+ Base = declarative_base()
+
+ class Node(Base):
+ __tablename__ = 'node'
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(Integer, ForeignKey('node.id'))
+ parent = relationship("Node", remote_side=id)
+
+Suppose we wanted to add an accessor ``grandparent``. This would return the
+``parent`` of ``Node.parent``. When we have an instance of ``Node``, this is
+simple::
+
+ from sqlalchemy.ext.hybrid import hybrid_property
+
+ class Node(Base):
+ # ...
+
+ @hybrid_property
+ def grandparent(self):
+ return self.parent.parent
+
+For the expression, things are not so clear. We'd need to construct a
+:class:`_query.Query` where we :meth:`_query.Query.join` twice along
+``Node.parent`` to get to the ``grandparent``. We can instead return a
+transforming callable that we'll combine with the :class:`.Comparator` class to
+receive any :class:`_query.Query` object, and return a new one that's joined to
+the ``Node.parent`` attribute and filtered based on the given criterion::
+
+ from sqlalchemy.ext.hybrid import Comparator
+
+ class GrandparentTransformer(Comparator):
+ def operate(self, op, other, **kwargs):
+ def transform(q):
+ cls = self.__clause_element__()
+ parent_alias = aliased(cls)
+ return q.join(parent_alias, cls.parent).filter(
+ op(parent_alias.parent, other, **kwargs)
+ )
+
+ return transform
+
+ Base = declarative_base()
+
+ class Node(Base):
+ __tablename__ = 'node'
+ id = Column(Integer, primary_key=True)
+ parent_id = Column(Integer, ForeignKey('node.id'))
+ parent = relationship("Node", remote_side=id)
+
+ @hybrid_property
+ def grandparent(self):
+ return self.parent.parent
+
+ @grandparent.comparator
+ def grandparent(cls):
+ return GrandparentTransformer(cls)
+
+The ``GrandparentTransformer`` overrides the core :meth:`.Operators.operate`
+method at the base of the :class:`.Comparator` hierarchy to return a
+query-transforming callable, which then runs the given comparison operation
+in a particular context. In the example above, for instance, the ``operate``
+method is called with the :attr:`.Operators.eq` callable as well as the right
+side of the comparison, ``Node(id=5)``. A function ``transform`` is then
+returned which will transform a :class:`_query.Query` first to join to
+``Node.parent``, then to compare ``parent_alias`` using :attr:`.Operators.eq`
+against the left and right sides, passing into :meth:`_query.Query.filter`:
+
+.. sourcecode:: pycon+sql
+
+ >>> from sqlalchemy.orm import Session
+ >>> session = Session()
+ {sql}>>> session.query(Node).\
+ ... with_transformation(Node.grandparent==Node(id=5)).\
+ ... all()
+ SELECT node.id AS node_id, node.parent_id AS node_parent_id
+ FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
+ WHERE :param_1 = node_1.parent_id
+ {stop}
+
+We can modify the pattern to be more verbose but flexible by separating the
+"join" step from the "filter" step. The tricky part here is ensuring that
+successive instances of ``GrandparentTransformer`` use the same
+:class:`.AliasedClass` object against ``Node``. Below we use a simple
+memoizing approach that associates a ``GrandparentTransformer`` with each
+class::
+
+ class Node(Base):
+
+ # ...
+
+ @grandparent.comparator
+ def grandparent(cls):
+ # memoize a GrandparentTransformer
+ # per class
+ if '_gp' not in cls.__dict__:
+ cls._gp = GrandparentTransformer(cls)
+ return cls._gp
+
+ class GrandparentTransformer(Comparator):
+
+ def __init__(self, cls):
+ self.parent_alias = aliased(cls)
+
+ @property
+ def join(self):
+ def go(q):
+ return q.join(self.parent_alias, Node.parent)
+ return go
+
+ def operate(self, op, other, **kwargs):
+ return op(self.parent_alias.parent, other, **kwargs)
+
+.. sourcecode:: pycon+sql
+
+ {sql}>>> session.query(Node).\
+ ... with_transformation(Node.grandparent.join).\
+ ... filter(Node.grandparent==Node(id=5))
+ SELECT node.id AS node_id, node.parent_id AS node_parent_id
+ FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
+ WHERE :param_1 = node_1.parent_id
+ {stop}
+
+The "transformer" pattern is an experimental pattern that starts to make use
+of some functional programming paradigms. While it's only recommended for
+advanced and/or patient developers, there's probably a whole lot of amazing
+things it can be used for.
+
+""" # noqa
+from .. import util
+from ..orm import attributes
+from ..orm import interfaces
+
+HYBRID_METHOD = util.symbol("HYBRID_METHOD")
+"""Symbol indicating an :class:`InspectionAttr` that's
+ of type :class:`.hybrid_method`.
+
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+ .. seealso::
+
+    :attr:`_orm.Mapper.all_orm_descriptors`
+
+"""
+
+HYBRID_PROPERTY = util.symbol("HYBRID_PROPERTY")
+"""Symbol indicating an :class:`InspectionAttr` that's
+    of type :class:`.hybrid_property`.
+
+ Is assigned to the :attr:`.InspectionAttr.extension_type`
+ attribute.
+
+ .. seealso::
+
+    :attr:`_orm.Mapper.all_orm_descriptors`
+
+"""
+
+
+class hybrid_method(interfaces.InspectionAttrInfo):
+ """A decorator which allows definition of a Python object method with both
+ instance-level and class-level behavior.
+
+ """
+
+ is_attribute = True
+ extension_type = HYBRID_METHOD
+
+ def __init__(self, func, expr=None):
+ """Create a new :class:`.hybrid_method`.
+
+ Usage is typically via decorator::
+
+ from sqlalchemy.ext.hybrid import hybrid_method
+
+ class SomeClass(object):
+ @hybrid_method
+ def value(self, x, y):
+ return self._value + x + y
+
+ @value.expression
+            def value(cls, x, y):
+                return func.some_function(cls._value, x, y)
+
+ """
+ self.func = func
+ self.expression(expr or func)
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self.expr.__get__(owner, owner.__class__)
+ else:
+ return self.func.__get__(instance, owner)
+
+ def expression(self, expr):
+ """Provide a modifying decorator that defines a
+ SQL-expression producing method."""
+
+ self.expr = expr
+ if not self.expr.__doc__:
+ self.expr.__doc__ = self.func.__doc__
+ return self
+
+
+class hybrid_property(interfaces.InspectionAttrInfo):
+ """A decorator which allows definition of a Python descriptor with both
+ instance-level and class-level behavior.
+
+ """
+
+ is_attribute = True
+ extension_type = HYBRID_PROPERTY
+
+ def __init__(
+ self,
+ fget,
+ fset=None,
+ fdel=None,
+ expr=None,
+ custom_comparator=None,
+ update_expr=None,
+ ):
+ """Create a new :class:`.hybrid_property`.
+
+ Usage is typically via decorator::
+
+ from sqlalchemy.ext.hybrid import hybrid_property
+
+ class SomeClass(object):
+ @hybrid_property
+ def value(self):
+ return self._value
+
+ @value.setter
+ def value(self, value):
+ self._value = value
+
+ """
+ self.fget = fget
+ self.fset = fset
+ self.fdel = fdel
+ self.expr = expr
+ self.custom_comparator = custom_comparator
+ self.update_expr = update_expr
+ util.update_wrapper(self, fget)
+
+ def __get__(self, instance, owner):
+ if instance is None:
+ return self._expr_comparator(owner)
+ else:
+ return self.fget(instance)
+
+ def __set__(self, instance, value):
+ if self.fset is None:
+ raise AttributeError("can't set attribute")
+ self.fset(instance, value)
+
+ def __delete__(self, instance):
+ if self.fdel is None:
+ raise AttributeError("can't delete attribute")
+ self.fdel(instance)
+
+ def _copy(self, **kw):
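+        # construct a new hybrid_property of the same type, carrying over
+        # all public state and applying the given keyword overrides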
+ defaults = {
+ key: value
+ for key, value in self.__dict__.items()
+ if not key.startswith("_")
+ }
+ defaults.update(**kw)
+ return type(self)(**defaults)
+
+ @property
+ def overrides(self):
+ """Prefix for a method that is overriding an existing attribute.
+
+ The :attr:`.hybrid_property.overrides` accessor just returns
+ this hybrid object, which when called at the class level from
+ a parent class, will de-reference the "instrumented attribute"
+ normally returned at this level, and allow modifying decorators
+ like :meth:`.hybrid_property.expression` and
+ :meth:`.hybrid_property.comparator`
+ to be used without conflicting with the same-named attributes
+ normally present on the :class:`.QueryableAttribute`::
+
+ class SuperClass(object):
+ # ...
+
+ @hybrid_property
+ def foobar(self):
+ return self._foobar
+
+ class SubClass(SuperClass):
+ # ...
+
+ @SuperClass.foobar.overrides.expression
+ def foobar(cls):
+                return func.subfoobar(cls._foobar)
+
+ .. versionadded:: 1.2
+
+ .. seealso::
+
+ :ref:`hybrid_reuse_subclass`
+
+ """
+ return self
+
+ def getter(self, fget):
+ """Provide a modifying decorator that defines a getter method.
+
+ .. versionadded:: 1.2
+
+ """
+
+ return self._copy(fget=fget)
+
+ def setter(self, fset):
+ """Provide a modifying decorator that defines a setter method."""
+
+ return self._copy(fset=fset)
+
+ def deleter(self, fdel):
+ """Provide a modifying decorator that defines a deletion method."""
+
+ return self._copy(fdel=fdel)
+
+ def expression(self, expr):
+ """Provide a modifying decorator that defines a SQL-expression
+ producing method.
+
+ When a hybrid is invoked at the class level, the SQL expression given
+ here is wrapped inside of a specialized :class:`.QueryableAttribute`,
+ which is the same kind of object used by the ORM to represent other
+ mapped attributes. The reason for this is so that other class-level
+ attributes such as docstrings and a reference to the hybrid itself may
+ be maintained within the structure that's returned, without any
+ modifications to the original SQL expression passed in.
+
+ .. note::
+
+ When referring to a hybrid property from an owning class (e.g.
+ ``SomeClass.some_hybrid``), an instance of
+ :class:`.QueryableAttribute` is returned, representing the
+ expression or comparator object as well as this hybrid object.
+ However, that object itself has accessors called ``expression`` and
+ ``comparator``; so when attempting to override these decorators on a
+ subclass, it may be necessary to qualify it using the
+ :attr:`.hybrid_property.overrides` modifier first. See that
+ modifier for details.
+
+ .. seealso::
+
+ :ref:`hybrid_distinct_expression`
+
+ """
+
+ return self._copy(expr=expr)
+
+ def comparator(self, comparator):
+ """Provide a modifying decorator that defines a custom
+ comparator producing method.
+
+ The return value of the decorated method should be an instance of
+ :class:`~.hybrid.Comparator`.
+
+ .. note:: The :meth:`.hybrid_property.comparator` decorator
+ **replaces** the use of the :meth:`.hybrid_property.expression`
+ decorator. They cannot be used together.
+
+ When a hybrid is invoked at the class level, the
+ :class:`~.hybrid.Comparator` object given here is wrapped inside of a
+ specialized :class:`.QueryableAttribute`, which is the same kind of
+ object used by the ORM to represent other mapped attributes. The
+ reason for this is so that other class-level attributes such as
+ docstrings and a reference to the hybrid itself may be maintained
+ within the structure that's returned, without any modifications to the
+ original comparator object passed in.
+
+ .. note::
+
+ When referring to a hybrid property from an owning class (e.g.
+ ``SomeClass.some_hybrid``), an instance of
+ :class:`.QueryableAttribute` is returned, representing the
+            expression or comparator object as well as this hybrid object.
+            However, that object itself has accessors called ``expression``
+            and ``comparator``; so when attempting to override these
+            decorators on a
+ subclass, it may be necessary to qualify it using the
+ :attr:`.hybrid_property.overrides` modifier first. See that
+ modifier for details.
+
+ """
+ return self._copy(custom_comparator=comparator)
+
+ def update_expression(self, meth):
+ """Provide a modifying decorator that defines an UPDATE tuple
+ producing method.
+
+ The method accepts a single value, which is the value to be
+ rendered into the SET clause of an UPDATE statement. The method
+ should then process this value into individual column expressions
+ that fit into the ultimate SET clause, and return them as a
+ sequence of 2-tuples. Each tuple
+ contains a column expression as the key and a value to be rendered.
+
+ E.g.::
+
+ class Person(Base):
+ # ...
+
+ first_name = Column(String)
+ last_name = Column(String)
+
+ @hybrid_property
+ def fullname(self):
+                return self.first_name + " " + self.last_name
+
+ @fullname.update_expression
+ def fullname(cls, value):
+ fname, lname = value.split(" ", 1)
+ return [
+ (cls.first_name, fname),
+ (cls.last_name, lname)
+ ]
+
+ .. versionadded:: 1.2
+
+ """
+ return self._copy(update_expr=meth)
+
+ @util.memoized_property
+ def _expr_comparator(self):
+ if self.custom_comparator is not None:
+ return self._get_comparator(self.custom_comparator)
+ elif self.expr is not None:
+ return self._get_expr(self.expr)
+ else:
+ return self._get_expr(self.fget)
+
+ def _get_expr(self, expr):
+ def _expr(cls):
+ return ExprComparator(cls, expr(cls), self)
+
+ util.update_wrapper(_expr, expr)
+
+ return self._get_comparator(_expr)
+
+ def _get_comparator(self, comparator):
+
+ proxy_attr = attributes.create_proxied_attribute(self)
+
+ def expr_comparator(owner):
+ # because this is the descriptor protocol, we don't really know
+ # what our attribute name is. so search for it through the
+ # MRO.
+ for lookup in owner.__mro__:
+ if self.__name__ in lookup.__dict__:
+ if lookup.__dict__[self.__name__] is self:
+ name = self.__name__
+ break
+ else:
+ name = attributes.NO_KEY
+
+ return proxy_attr(
+ owner,
+ name,
+ self,
+ comparator(owner),
+ doc=comparator.__doc__ or self.__doc__,
+ )
+
+ return expr_comparator
+
+
+class Comparator(interfaces.PropComparator):
+ """A helper class that allows easy construction of custom
+ :class:`~.orm.interfaces.PropComparator`
+ classes for usage with hybrids."""
+
+ property = None
+
+ def __init__(self, expression):
+ self.expression = expression
+
+ def __clause_element__(self):
+ expr = self.expression
+ if hasattr(expr, "__clause_element__"):
+ expr = expr.__clause_element__()
+ return expr
+
+ def adapt_to_entity(self, adapt_to_entity):
+        # hybrid comparators are not re-adapted to an aliased entity;
+        # return this comparator unchanged
+ return self
+
+
+class ExprComparator(Comparator):
+ def __init__(self, cls, expression, hybrid):
+ self.cls = cls
+ self.expression = expression
+ self.hybrid = hybrid
+
+ def __getattr__(self, key):
+ return getattr(self.expression, key)
+
+ @property
+ def info(self):
+ return self.hybrid.info
+
+ def _bulk_update_tuples(self, value):
+ if isinstance(self.expression, attributes.QueryableAttribute):
+ return self.expression._bulk_update_tuples(value)
+ elif self.hybrid.update_expr is not None:
+ return self.hybrid.update_expr(self.cls, value)
+ else:
+ return [(self.expression, value)]
+
+ @property
+ def property(self):
+ return self.expression.property
+
+ def operate(self, op, *other, **kwargs):
+ return op(self.expression, *other, **kwargs)
+
+ def reverse_operate(self, op, other, **kwargs):
+ return op(other, self.expression, **kwargs)
diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py
new file mode 100644
index 0000000..7cbac54
--- /dev/null
+++ b/lib/sqlalchemy/ext/indexable.py
@@ -0,0 +1,352 @@
+# ext/indexable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define attributes on ORM-mapped classes that have "index" attributes for
+columns with :class:`_types.Indexable` types.
+
+"index" means that the attribute is associated with an element of an
+:class:`_types.Indexable` column, using a predefined index to access
+that element.
+The :class:`_types.Indexable` types include types such as
+:class:`_types.ARRAY`, :class:`_types.JSON` and
+:class:`_postgresql.HSTORE`.
+
+
+
+The :mod:`~sqlalchemy.ext.indexable` extension provides
+:class:`_schema.Column`-like interface for any element of an
+:class:`_types.Indexable` typed column. In simple cases, it can be
+treated as a :class:`_schema.Column` - mapped attribute.
+
+
+.. versionadded:: 1.1
+
+Synopsis
+========
+
+Start with a ``Person`` model that has a primary key and a JSON data field.
+While this field may have any number of elements encoded within it,
+we would like to refer to the element called ``name`` individually
+as a dedicated attribute which behaves like a standalone column::
+
+ from sqlalchemy import Column, JSON, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.indexable import index_property
+
+ Base = declarative_base()
+
+ class Person(Base):
+ __tablename__ = 'person'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(JSON)
+
+ name = index_property('data', 'name')
+
+
+Above, the ``name`` attribute now behaves like a mapped column. We
+can compose a new ``Person`` and set the value of ``name``::
+
+ >>> person = Person(name='Alchemist')
+
+The value is now accessible::
+
+ >>> person.name
+ 'Alchemist'
+
+Behind the scenes, the JSON field was initialized to a new blank dictionary
+and the field was set::
+
+ >>> person.data
+    {'name': 'Alchemist'}
+
+The field is mutable in place::
+
+ >>> person.name = 'Renamed'
+ >>> person.name
+ 'Renamed'
+ >>> person.data
+ {'name': 'Renamed'}
+
+When using :class:`.index_property`, the change that we make to the indexable
+structure is also automatically tracked as history; we no longer need
+to use :class:`~.mutable.MutableDict` in order to track this change
+for the unit of work.
+
+Deletions work normally as well::
+
+ >>> del person.name
+ >>> person.data
+ {}
+
+Above, deletion of ``person.name`` deletes the value from the dictionary,
+but not the dictionary itself.
+
+A missing key will produce ``AttributeError``::
+
+ >>> person = Person()
+ >>> person.name
+ ...
+ AttributeError: 'name'
+
+Unless you set a default value::
+
+    >>> class Person(Base):
+    ...     __tablename__ = 'person'
+    ...
+    ...     id = Column(Integer, primary_key=True)
+    ...     data = Column(JSON)
+    ...
+    ...     name = index_property('data', 'name', default=None)  # See default
+
+ >>> person = Person()
+ >>> print(person.name)
+ None
+
+
+The attributes are also accessible at the class level.
+Below, we illustrate ``Person.name`` used to generate
+an indexed SQL criterion::
+
+ >>> from sqlalchemy.orm import Session
+ >>> session = Session()
+ >>> query = session.query(Person).filter(Person.name == 'Alchemist')
+
+The above query is equivalent to::
+
+ >>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist')
+
+Multiple :class:`.index_property` objects can be chained to produce
+multiple levels of indexing::
+
+ from sqlalchemy import Column, JSON, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.ext.indexable import index_property
+
+ Base = declarative_base()
+
+ class Person(Base):
+ __tablename__ = 'person'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(JSON)
+
+ birthday = index_property('data', 'birthday')
+ year = index_property('birthday', 'year')
+ month = index_property('birthday', 'month')
+ day = index_property('birthday', 'day')
+
+With the above mapping, consider a query such as::
+
+ q = session.query(Person).filter(Person.year == '1980')
+
+On a PostgreSQL backend, the above query will render as::
+
+ SELECT person.id, person.data
+ FROM person
+ WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s
+
+Default Values
+==============
+
+:class:`.index_property` includes special behaviors for when the indexed
+data structure does not exist, and a set operation is called:
+
+* For an :class:`.index_property` that is given an integer index value,
+  the default data structure will be a Python list of ``None`` values,
+  long enough to accommodate the index value; the value is then set at
+  its place in the list. This means for an index value of zero, the list
+  will be initialized to ``[None]`` before setting the given value,
+  and for an index value of five, the list will be initialized to
+  ``[None, None, None, None, None, None]`` before setting the element
+  at index five to the given value. Note that an existing list is **not**
+  extended in place to receive a value.
+
+* For an :class:`.index_property` that is given any other kind of index
+  value (e.g. strings, usually), a Python dictionary is used as the
+  default data structure.
+
+* The default data structure can be set to any Python callable using the
+  :paramref:`.index_property.datatype` parameter, overriding the previous
+  rules; a short sketch follows this list.
+
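+A minimal sketch of the last case, assuming we'd like the default structure
+for a string index to be a ``collections.OrderedDict`` rather than a plain
+dictionary (the mapping reuses the ``Person`` model from the synopsis)::
+
+    import collections
+
+    class Person(Base):
+        __tablename__ = 'person'
+
+        id = Column(Integer, primary_key=True)
+        data = Column(JSON)
+
+        # datatype overrides the default dict used when "data" is None
+        name = index_property('data', 'name', datatype=collections.OrderedDict)
+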
+
+Subclassing
+===========
+
+:class:`.index_property` can be subclassed, in particular for the common
+use case of providing coercion of values or SQL expressions as they are
+accessed. Below is a common recipe for use with a PostgreSQL JSON type,
+where we want to also include automatic casting plus ``astext()``::
+
+ class pg_json_property(index_property):
+ def __init__(self, attr_name, index, cast_type):
+ super(pg_json_property, self).__init__(attr_name, index)
+ self.cast_type = cast_type
+
+ def expr(self, model):
+ expr = super(pg_json_property, self).expr(model)
+ return expr.astext.cast(self.cast_type)
+
+The above subclass can be used with the PostgreSQL-specific
+version of :class:`_postgresql.JSON`::
+
+ from sqlalchemy import Column, Integer
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.dialects.postgresql import JSON
+
+ Base = declarative_base()
+
+ class Person(Base):
+ __tablename__ = 'person'
+
+ id = Column(Integer, primary_key=True)
+ data = Column(JSON)
+
+ age = pg_json_property('data', 'age', Integer)
+
+The ``age`` attribute at the instance level works as before; however
+when rendering SQL, PostgreSQL's ``->>`` operator will be used
+for indexed access, instead of the usual index operator of ``->``::
+
+ >>> query = session.query(Person).filter(Person.age < 20)
+
+The above query will render::
+
+ SELECT person.id, person.data
+ FROM person
+ WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s
+
+""" # noqa
+from __future__ import absolute_import
+
+from .. import inspect
+from .. import util
+from ..ext.hybrid import hybrid_property
+from ..orm.attributes import flag_modified
+
+
+__all__ = ["index_property"]
+
+
+class index_property(hybrid_property): # noqa
+ """A property generator. The generated property describes an object
+ attribute that corresponds to an :class:`_types.Indexable`
+ column.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :mod:`sqlalchemy.ext.indexable`
+
+ """
+
+ _NO_DEFAULT_ARGUMENT = object()
+
+ def __init__(
+ self,
+ attr_name,
+ index,
+ default=_NO_DEFAULT_ARGUMENT,
+ datatype=None,
+ mutable=True,
+ onebased=True,
+ ):
+ """Create a new :class:`.index_property`.
+
+        :param attr_name:
+            An attribute name of an ``Indexable`` typed column, or other
+            attribute that returns an indexable structure.
+        :param index:
+            The index to be used for getting and setting this value. This
+            should be the Python-side index value for integers.
+        :param default:
+            A value which will be returned instead of ``AttributeError``
+            when there is no value at the given index.
+ :param datatype: default datatype to use when the field is empty.
+ By default, this is derived from the type of index used; a
+ Python list for an integer index, or a Python dictionary for
+ any other style of index. For a list, the list will be
+ initialized to a list of None values that is at least
+ ``index`` elements long.
+ :param mutable: if False, writes and deletes to the attribute will
+ be disallowed.
+ :param onebased: assume the SQL representation of this value is
+ one-based; that is, the first index in SQL is 1, not zero.
+ """
+
+ if mutable:
+ super(index_property, self).__init__(
+ self.fget, self.fset, self.fdel, self.expr
+ )
+ else:
+ super(index_property, self).__init__(
+ self.fget, None, None, self.expr
+ )
+ self.attr_name = attr_name
+ self.index = index
+ self.default = default
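+        # the one-based adjustment applies only to integer indexes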
+ is_numeric = isinstance(index, int)
+ onebased = is_numeric and onebased
+
+ if datatype is not None:
+ self.datatype = datatype
+ else:
+ if is_numeric:
+ self.datatype = lambda: [None for x in range(index + 1)]
+ else:
+ self.datatype = dict
+ self.onebased = onebased
+
+ def _fget_default(self, err=None):
+        if self.default is self._NO_DEFAULT_ARGUMENT:
+ util.raise_(AttributeError(self.attr_name), replace_context=err)
+ else:
+ return self.default
+
+ def fget(self, instance):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name)
+ if column_value is None:
+ return self._fget_default()
+ try:
+ value = column_value[self.index]
+ except (KeyError, IndexError) as err:
+ return self._fget_default(err)
+ else:
+ return value
+
+ def fset(self, instance, value):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name, None)
+ if column_value is None:
+ column_value = self.datatype()
+ setattr(instance, attr_name, column_value)
+ column_value[self.index] = value
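+        # re-assign the whole structure so that an attribute set event
+        # fires; in-place mutation alone is not tracked by the ORM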
+ setattr(instance, attr_name, column_value)
+ if attr_name in inspect(instance).mapper.attrs:
+ flag_modified(instance, attr_name)
+
+ def fdel(self, instance):
+ attr_name = self.attr_name
+ column_value = getattr(instance, attr_name)
+ if column_value is None:
+ raise AttributeError(self.attr_name)
+ try:
+ del column_value[self.index]
+ except KeyError as err:
+ util.raise_(AttributeError(self.attr_name), replace_context=err)
+ else:
+ setattr(instance, attr_name, column_value)
+ flag_modified(instance, attr_name)
+
+ def expr(self, model):
+ column = getattr(model, self.attr_name)
+ index = self.index
+ if self.onebased:
+ index += 1
+ return column[index]
diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py
new file mode 100644
index 0000000..54f3e64
--- /dev/null
+++ b/lib/sqlalchemy/ext/instrumentation.py
@@ -0,0 +1,416 @@
+"""Extensible class instrumentation.
+
+The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate
+systems of class instrumentation within the ORM. Class instrumentation
+refers to how the ORM places attributes on the class which maintain
+data and track changes to that data, as well as event hooks installed
+on the class.
+
+.. note::
+ The extension package is provided for the benefit of integration
+ with other object management packages, which already perform
+ their own instrumentation. It is not intended for general use.
+
+For examples of how the instrumentation extension is used,
+see the example :ref:`examples_instrumentation`.
+
+"""
+import weakref
+
+from .. import util
+from ..orm import attributes
+from ..orm import base as orm_base
+from ..orm import collections
+from ..orm import exc as orm_exc
+from ..orm import instrumentation as orm_instrumentation
+from ..orm.instrumentation import _default_dict_getter
+from ..orm.instrumentation import _default_manager_getter
+from ..orm.instrumentation import _default_state_getter
+from ..orm.instrumentation import ClassManager
+from ..orm.instrumentation import InstrumentationFactory
+
+
+INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__"
+"""Attribute which, when present on a mapped class, elects custom
+instrumentation.
+
+Allows a class to specify a slightly or wildly different technique for
+tracking changes made to mapped attributes and collections.
+
+Only one instrumentation implementation is allowed in a given object
+inheritance hierarchy.
+
+The value of this attribute must be a callable and will be passed a class
+object. The callable must return one of:
+
+ - An instance of an :class:`.InstrumentationManager` or subclass
+ - An object implementing all or some of InstrumentationManager (TODO)
+ - A dictionary of callables, implementing all or some of the above (TODO)
+ - An instance of a :class:`.ClassManager` or subclass
+
+This attribute is consulted by SQLAlchemy instrumentation
+resolution, once the :mod:`sqlalchemy.ext.instrumentation` module
+has been imported. If custom finders are installed in the global
+instrumentation_finders list, they may or may not choose to honor this
+attribute.
+
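+A minimal sketch of electing custom instrumentation, assuming an
+:class:`.InstrumentationManager` subclass (the class names here are
+illustrative)::
+
+    from sqlalchemy.ext.instrumentation import InstrumentationManager
+
+    class MyClassStateManager(InstrumentationManager):
+        # override hooks such as install_state() or state_getter() here
+        pass
+
+    class MyMappedClass(object):
+        __sa_instrumentation_manager__ = MyClassStateManager
+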
+"""
+
+
+def find_native_user_instrumentation_hook(cls):
+ """Find user-specified instrumentation management for a class."""
+ return getattr(cls, INSTRUMENTATION_MANAGER, None)
+
+
+instrumentation_finders = [find_native_user_instrumentation_hook]
+"""An extensible sequence of callables which return instrumentation
+implementations.
+
+When a class is registered, each callable will be passed a class object.
+If None is returned, the
+next finder in the sequence is consulted. Otherwise the return must be an
+instrumentation factory that follows the same guidelines as
+sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER.
+
+By default, the only finder is find_native_user_instrumentation_hook, which
+searches for INSTRUMENTATION_MANAGER. If all finders return None, standard
+ClassManager instrumentation is used.
+
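+For example, an additional finder could be installed ahead of the default
+(the attribute name consulted here is illustrative)::
+
+    from sqlalchemy.ext.instrumentation import instrumentation_finders
+
+    def find_my_manager(cls):
+        # returning None defers to the next finder in the sequence
+        return getattr(cls, "_my_instrumentation_manager", None)
+
+    instrumentation_finders.insert(0, find_my_manager)
+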
+"""
+
+
+class ExtendedInstrumentationRegistry(InstrumentationFactory):
+ """Extends :class:`.InstrumentationFactory` with additional
+ bookkeeping, to accommodate multiple types of
+ class managers.
+
+ """
+
+ _manager_finders = weakref.WeakKeyDictionary()
+ _state_finders = weakref.WeakKeyDictionary()
+ _dict_finders = weakref.WeakKeyDictionary()
+ _extended = False
+
+ def _locate_extended_factory(self, class_):
+ for finder in instrumentation_finders:
+ factory = finder(class_)
+ if factory is not None:
+ manager = self._extended_class_manager(class_, factory)
+ return manager, factory
+ else:
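+            # for/else: the loop completed without any finder
+            # returning a factory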
+ return None, None
+
+ def _check_conflicts(self, class_, factory):
+ existing_factories = self._collect_management_factories_for(
+ class_
+ ).difference([factory])
+ if existing_factories:
+ raise TypeError(
+ "multiple instrumentation implementations specified "
+ "in %s inheritance hierarchy: %r"
+ % (class_.__name__, list(existing_factories))
+ )
+
+ def _extended_class_manager(self, class_, factory):
+ manager = factory(class_)
+ if not isinstance(manager, ClassManager):
+ manager = _ClassInstrumentationAdapter(class_, manager)
+
+ if factory != ClassManager and not self._extended:
+ # somebody invoked a custom ClassManager.
+ # reinstall global "getter" functions with the more
+ # expensive ones.
+ self._extended = True
+ _install_instrumented_lookups()
+
+ self._manager_finders[class_] = manager.manager_getter()
+ self._state_finders[class_] = manager.state_getter()
+ self._dict_finders[class_] = manager.dict_getter()
+ return manager
+
+ def _collect_management_factories_for(self, cls):
+ """Return a collection of factories in play or specified for a
+ hierarchy.
+
+ Traverses the entire inheritance graph of a cls and returns a
+ collection of instrumentation factories for those classes. Factories
+ are extracted from active ClassManagers, if available, otherwise
+ instrumentation_finders is consulted.
+
+ """
+ hierarchy = util.class_hierarchy(cls)
+ factories = set()
+ for member in hierarchy:
+ manager = self.manager_of_class(member)
+ if manager is not None:
+ factories.add(manager.factory)
+ else:
+ for finder in instrumentation_finders:
+ factory = finder(member)
+ if factory is not None:
+ break
+ else:
+ factory = None
+ factories.add(factory)
+ factories.discard(None)
+ return factories
+
+ def unregister(self, class_):
+ super(ExtendedInstrumentationRegistry, self).unregister(class_)
+ if class_ in self._manager_finders:
+ del self._manager_finders[class_]
+ del self._state_finders[class_]
+ del self._dict_finders[class_]
+
+ def manager_of_class(self, cls):
+ if cls is None:
+ return None
+ try:
+ finder = self._manager_finders.get(cls, _default_manager_getter)
+ except TypeError:
+ # due to weakref lookup on invalid object
+ return None
+ else:
+ return finder(cls)
+
+ def state_of(self, instance):
+ if instance is None:
+ raise AttributeError("None has no persistent state.")
+ return self._state_finders.get(
+ instance.__class__, _default_state_getter
+ )(instance)
+
+ def dict_of(self, instance):
+ if instance is None:
+ raise AttributeError("None has no persistent state.")
+ return self._dict_finders.get(
+ instance.__class__, _default_dict_getter
+ )(instance)
+
+
+orm_instrumentation._instrumentation_factory = (
+ _instrumentation_factory
+) = ExtendedInstrumentationRegistry()
+orm_instrumentation.instrumentation_finders = instrumentation_finders
+
+
+class InstrumentationManager(object):
+ """User-defined class instrumentation extension.
+
+ :class:`.InstrumentationManager` can be subclassed in order
+ to change
+ how class instrumentation proceeds. This class exists for
+ the purposes of integration with other object management
+ frameworks which would like to entirely modify the
+ instrumentation methodology of the ORM, and is not intended
+ for regular usage. For interception of class instrumentation
+ events, see :class:`.InstrumentationEvents`.
+
+ The API for this class should be considered as semi-stable,
+ and may change slightly with new releases.
+
+ """
+
+ # r4361 added a mandatory (cls) constructor to this interface.
+ # given that, perhaps class_ should be dropped from all of these
+ # signatures.
+
+ def __init__(self, class_):
+ pass
+
+ def manage(self, class_, manager):
+ setattr(class_, "_default_class_manager", manager)
+
+ def unregister(self, class_, manager):
+ delattr(class_, "_default_class_manager")
+
+ def manager_getter(self, class_):
+ def get(cls):
+ return cls._default_class_manager
+
+ return get
+
+ def instrument_attribute(self, class_, key, inst):
+ pass
+
+ def post_configure_attribute(self, class_, key, inst):
+ pass
+
+ def install_descriptor(self, class_, key, inst):
+ setattr(class_, key, inst)
+
+ def uninstall_descriptor(self, class_, key):
+ delattr(class_, key)
+
+ def install_member(self, class_, key, implementation):
+ setattr(class_, key, implementation)
+
+ def uninstall_member(self, class_, key):
+ delattr(class_, key)
+
+ def instrument_collection_class(self, class_, key, collection_class):
+ return collections.prepare_instrumentation(collection_class)
+
+ def get_instance_dict(self, class_, instance):
+ return instance.__dict__
+
+ def initialize_instance_dict(self, class_, instance):
+ pass
+
+ def install_state(self, class_, instance, state):
+ setattr(instance, "_default_state", state)
+
+ def remove_state(self, class_, instance):
+ delattr(instance, "_default_state")
+
+ def state_getter(self, class_):
+ return lambda instance: getattr(instance, "_default_state")
+
+ def dict_getter(self, class_):
+ return lambda inst: self.get_instance_dict(class_, inst)
+
+
+class _ClassInstrumentationAdapter(ClassManager):
+ """Adapts a user-defined InstrumentationManager to a ClassManager."""
+
+ def __init__(self, class_, override):
+ self._adapted = override
+ self._get_state = self._adapted.state_getter(class_)
+ self._get_dict = self._adapted.dict_getter(class_)
+
+ ClassManager.__init__(self, class_)
+
+ def manage(self):
+ self._adapted.manage(self.class_, self)
+
+ def unregister(self):
+ self._adapted.unregister(self.class_, self)
+
+ def manager_getter(self):
+ return self._adapted.manager_getter(self.class_)
+
+ def instrument_attribute(self, key, inst, propagated=False):
+ ClassManager.instrument_attribute(self, key, inst, propagated)
+ if not propagated:
+ self._adapted.instrument_attribute(self.class_, key, inst)
+
+ def post_configure_attribute(self, key):
+ super(_ClassInstrumentationAdapter, self).post_configure_attribute(key)
+ self._adapted.post_configure_attribute(self.class_, key, self[key])
+
+ def install_descriptor(self, key, inst):
+ self._adapted.install_descriptor(self.class_, key, inst)
+
+ def uninstall_descriptor(self, key):
+ self._adapted.uninstall_descriptor(self.class_, key)
+
+ def install_member(self, key, implementation):
+ self._adapted.install_member(self.class_, key, implementation)
+
+ def uninstall_member(self, key):
+ self._adapted.uninstall_member(self.class_, key)
+
+ def instrument_collection_class(self, key, collection_class):
+ return self._adapted.instrument_collection_class(
+ self.class_, key, collection_class
+ )
+
+ def initialize_collection(self, key, state, factory):
+ delegate = getattr(self._adapted, "initialize_collection", None)
+ if delegate:
+ return delegate(key, state, factory)
+ else:
+ return ClassManager.initialize_collection(
+ self, key, state, factory
+ )
+
+ def new_instance(self, state=None):
+ instance = self.class_.__new__(self.class_)
+ self.setup_instance(instance, state)
+ return instance
+
+ def _new_state_if_none(self, instance):
+ """Install a default InstanceState if none is present.
+
+ A private convenience method used by the __init__ decorator.
+ """
+ if self.has_state(instance):
+ return False
+ else:
+ return self.setup_instance(instance)
+
+ def setup_instance(self, instance, state=None):
+ self._adapted.initialize_instance_dict(self.class_, instance)
+
+ if state is None:
+ state = self._state_constructor(instance, self)
+
+ # the given instance is assumed to have no state
+ self._adapted.install_state(self.class_, instance, state)
+ return state
+
+ def teardown_instance(self, instance):
+ self._adapted.remove_state(self.class_, instance)
+
+ def has_state(self, instance):
+ try:
+ self._get_state(instance)
+ except orm_exc.NO_STATE:
+ return False
+ else:
+ return True
+
+ def state_getter(self):
+ return self._get_state
+
+ def dict_getter(self):
+ return self._get_dict
+
+
+def _install_instrumented_lookups():
+ """Replace global class/object management functions
+ with ExtendedInstrumentationRegistry implementations, which
+ allow multiple types of class managers to be present,
+ at the cost of performance.
+
+ This function is called only by ExtendedInstrumentationRegistry
+ and unit tests specific to this behavior.
+
+ The _reinstall_default_lookups() function can be called
+ after this one to re-establish the default functions.
+
+ """
+ _install_lookups(
+ dict(
+ instance_state=_instrumentation_factory.state_of,
+ instance_dict=_instrumentation_factory.dict_of,
+ manager_of_class=_instrumentation_factory.manager_of_class,
+ )
+ )
+
+
+def _reinstall_default_lookups():
+ """Restore simplified lookups."""
+ _install_lookups(
+ dict(
+ instance_state=_default_state_getter,
+ instance_dict=_default_dict_getter,
+ manager_of_class=_default_manager_getter,
+ )
+ )
+ _instrumentation_factory._extended = False
+
+
+def _install_lookups(lookups):
+ global instance_state, instance_dict, manager_of_class
+ instance_state = lookups["instance_state"]
+ instance_dict = lookups["instance_dict"]
+ manager_of_class = lookups["manager_of_class"]
+ orm_base.instance_state = (
+ attributes.instance_state
+ ) = orm_instrumentation.instance_state = instance_state
+ orm_base.instance_dict = (
+ attributes.instance_dict
+ ) = orm_instrumentation.instance_dict = instance_dict
+ orm_base.manager_of_class = (
+ attributes.manager_of_class
+ ) = orm_instrumentation.manager_of_class = manager_of_class
diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py
new file mode 100644
index 0000000..cbec06a
--- /dev/null
+++ b/lib/sqlalchemy/ext/mutable.py
@@ -0,0 +1,958 @@
+# ext/mutable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+r"""Provide support for tracking of in-place changes to scalar values,
+which are propagated into ORM change events on owning parent objects.
+
+.. _mutable_scalars:
+
+Establishing Mutability on Scalar Column Values
+===============================================
+
+A typical example of a "mutable" structure is a Python dictionary.
+Following the example introduced in :ref:`types_toplevel`, we
+begin with a custom type that marshals Python dictionaries into
+JSON strings before being persisted::
+
+ from sqlalchemy.types import TypeDecorator, VARCHAR
+ import json
+
+ class JSONEncodedDict(TypeDecorator):
+ "Represents an immutable structure as a json-encoded string."
+
+ impl = VARCHAR
+
+ def process_bind_param(self, value, dialect):
+ if value is not None:
+ value = json.dumps(value)
+ return value
+
+ def process_result_value(self, value, dialect):
+ if value is not None:
+ value = json.loads(value)
+ return value
+
+The usage of ``json`` is only for the purposes of example. The
+:mod:`sqlalchemy.ext.mutable` extension can be used
+with any type whose target Python type may be mutable, including
+:class:`.PickleType`, :class:`_postgresql.ARRAY`, etc.
+
+When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself
+tracks all parents which reference it. Below, we illustrate a simple
+version of the :class:`.MutableDict` dictionary object, which applies
+the :class:`.Mutable` mixin to a plain Python dictionary::
+
+ from sqlalchemy.ext.mutable import Mutable
+
+ class MutableDict(Mutable, dict):
+ @classmethod
+ def coerce(cls, key, value):
+ "Convert plain dictionaries to MutableDict."
+
+ if not isinstance(value, MutableDict):
+ if isinstance(value, dict):
+ return MutableDict(value)
+
+ # this call will raise ValueError
+ return Mutable.coerce(key, value)
+ else:
+ return value
+
+ def __setitem__(self, key, value):
+ "Detect dictionary set events and emit change events."
+
+ dict.__setitem__(self, key, value)
+ self.changed()
+
+ def __delitem__(self, key):
+ "Detect dictionary del events and emit change events."
+
+ dict.__delitem__(self, key)
+ self.changed()
+
+The above dictionary class takes the approach of subclassing the Python
+built-in ``dict`` to produce a subclass which routes all mutation events
+through ``__setitem__`` and ``__delitem__``. There are
+variants on this approach, such as subclassing ``UserDict.UserDict`` or
+``collections.MutableMapping``; the part that's important to this example is
+that the :meth:`.Mutable.changed` method is called whenever an in-place
+change to the datastructure takes place.
+
+We also redefine the :meth:`.Mutable.coerce` method which will be used to
+convert any values that are not instances of ``MutableDict``, such
+as the plain dictionaries returned by the ``json`` module, into the
+appropriate type. Defining this method is optional; we could just as well
+have created our ``JSONEncodedDict`` such that it always returns an instance
+of ``MutableDict``, and additionally ensured that all calling code
+uses ``MutableDict`` explicitly. When :meth:`.Mutable.coerce` is not
+overridden, any values applied to a parent object which are not instances
+of the mutable type will raise a ``ValueError``.
+
+Our new ``MutableDict`` type offers a class method
+:meth:`~.Mutable.as_mutable` which we can use within column metadata
+to associate with types. This method grabs the given type object or
+class and associates a listener that will detect all future mappings
+of this type, applying event listening instrumentation to the mapped
+attribute. For example, with classical table metadata::
+
+ from sqlalchemy import Table, Column, Integer
+
+ my_data = Table('my_data', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', MutableDict.as_mutable(JSONEncodedDict))
+ )
+
+Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict``
+(if the type object was not an instance already), which will intercept any
+attributes which are mapped against this type. Below we establish a simple
+mapping against the ``my_data`` table::
+
+    from sqlalchemy.orm import mapper
+
+ class MyDataClass(object):
+ pass
+
+ # associates mutation listeners with MyDataClass.data
+ mapper(MyDataClass, my_data)
+
+The ``MyDataClass.data`` member will now be notified of in-place changes
+to its value.
+
+There's no difference in usage when using declarative::
+
+ from sqlalchemy.ext.declarative import declarative_base
+
+ Base = declarative_base()
+
+ class MyDataClass(Base):
+ __tablename__ = 'my_data'
+ id = Column(Integer, primary_key=True)
+ data = Column(MutableDict.as_mutable(JSONEncodedDict))
+
+Any in-place changes to the ``MyDataClass.data`` member
+will flag the attribute as "dirty" on the parent object::
+
+ >>> from sqlalchemy.orm import Session
+
+ >>> sess = Session()
+ >>> m1 = MyDataClass(data={'value1':'foo'})
+ >>> sess.add(m1)
+ >>> sess.commit()
+
+ >>> m1.data['value1'] = 'bar'
+    >>> m1 in sess.dirty
+ True
+
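+The same :meth:`.Mutable.coerce` scheme applies when the attribute is
+replaced in full; a plain dictionary assigned to ``m1.data`` is converted
+into ``MutableDict``, while a value that can't be coerced raises
+``ValueError``::
+
+    m1.data = {'value2': 'baz'}   # coerced to MutableDict
+    m1.data = 'not a dict'        # raises ValueError
+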
+The ``MutableDict`` can be associated with all future instances
+of ``JSONEncodedDict`` in one step, using
+:meth:`~.Mutable.associate_with`. This is similar to
+:meth:`~.Mutable.as_mutable` except it will intercept all occurrences
+of ``JSONEncodedDict`` in all mappings unconditionally, without
+the need to declare each one individually::
+
+ MutableDict.associate_with(JSONEncodedDict)
+
+ class MyDataClass(Base):
+ __tablename__ = 'my_data'
+ id = Column(Integer, primary_key=True)
+ data = Column(JSONEncodedDict)
+
+
+Supporting Pickling
+--------------------
+
+The :mod:`sqlalchemy.ext.mutable` extension relies upon the placement
+of a ``weakref.WeakKeyDictionary`` upon the value object, which
+stores a mapping of parent mapped objects keyed to the attribute name under
+which they are associated with this value. ``WeakKeyDictionary`` objects are
+not picklable, due to the fact that they contain weakrefs and function
+callbacks. In our case, this is a good thing, since if this dictionary were
+picklable, it could lead to an excessively large pickle size for our value
+objects that are pickled by themselves outside of the context of the parent.
+The only developer responsibility here is to provide a ``__getstate__``
+method that excludes the :attr:`~.MutableBase._parents` collection from the
+stream::
+
+ class MyMutableType(Mutable):
+ def __getstate__(self):
+ d = self.__dict__.copy()
+ d.pop('_parents', None)
+ return d
+
+With our dictionary example, we need to return the contents of the dict itself
+(and also restore them on ``__setstate__``)::
+
+ class MutableDict(Mutable, dict):
+ # ....
+
+ def __getstate__(self):
+ return dict(self)
+
+ def __setstate__(self, state):
+ self.update(state)
+
+In the case that our mutable value object is pickled as it is attached to one
+or more parent objects that are also part of the pickle, the :class:`.Mutable`
+mixin will re-establish the :attr:`.Mutable._parents` collection on each value
+object as the owning parents themselves are unpickled.
+
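+As a minimal sketch, assuming the ``MyDataClass`` mapping shown earlier
+along with the ``__getstate__`` and ``__setstate__`` methods above, a parent
+object may be round-tripped through pickle with change tracking intact::
+
+    import pickle
+
+    m1 = MyDataClass(data={'value1': 'foo'})
+
+    # the WeakKeyDictionary is not pickled; _parents is re-established
+    # on the value object as m2 is unpickled
+    m2 = pickle.loads(pickle.dumps(m1))
+
+    m2.data['value1'] = 'bar'  # still emits a change event on m2
+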
+Receiving Events
+----------------
+
+The :meth:`.AttributeEvents.modified` event handler may be used to receive
+an event when a mutable scalar emits a change event. This event handler
+is called when the :func:`.attributes.flag_modified` function is called
+from within the mutable extension::
+
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy import event
+
+ Base = declarative_base()
+
+ class MyDataClass(Base):
+ __tablename__ = 'my_data'
+ id = Column(Integer, primary_key=True)
+ data = Column(MutableDict.as_mutable(JSONEncodedDict))
+
+ @event.listens_for(MyDataClass.data, "modified")
+ def modified_json(instance):
+ print("json value modified:", instance.data)
+
+.. _mutable_composites:
+
+Establishing Mutability on Composites
+=====================================
+
+Composites are a special ORM feature which allow a single scalar attribute to
+be assigned an object value which represents information "composed" from one
+or more columns from the underlying mapped table. The usual example is that of
+a geometric "point", and is introduced in :ref:`mapper_composite`.
+
+As is the case with :class:`.Mutable`, the user-defined composite class
+subclasses :class:`.MutableComposite` as a mixin, and detects and delivers
+change events to its parents via the :meth:`.MutableComposite.changed` method.
+In the case of a composite class, the detection is usually via the usage of
+Python descriptors (i.e. ``@property``), or alternatively via the special
+Python method ``__setattr__()``. Below we expand upon the ``Point`` class
+introduced in :ref:`mapper_composite` to subclass :class:`.MutableComposite`
+and to also route attribute set events via ``__setattr__`` to the
+:meth:`.MutableComposite.changed` method::
+
+ from sqlalchemy.ext.mutable import MutableComposite
+
+ class Point(MutableComposite):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __setattr__(self, key, value):
+ "Intercept set events"
+
+ # set the attribute
+ object.__setattr__(self, key, value)
+
+ # alert all parents to the change
+ self.changed()
+
+ def __composite_values__(self):
+ return self.x, self.y
+
+ def __eq__(self, other):
+ return isinstance(other, Point) and \
+ other.x == self.x and \
+ other.y == self.y
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+The :class:`.MutableComposite` class uses a Python metaclass to automatically
+establish listeners for any usage of :func:`_orm.composite` that specifies our
+``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` class,
+listeners are established which will route change events from ``Point``
+objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::
+
+ from sqlalchemy.orm import composite, mapper
+    from sqlalchemy import Table, Column, Integer
+
+ vertices = Table('vertices', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('x1', Integer),
+ Column('y1', Integer),
+ Column('x2', Integer),
+ Column('y2', Integer),
+ )
+
+    class Vertex(object):
+        def __init__(self, start, end):
+            self.start = start
+            self.end = end
+
+ mapper(Vertex, vertices, properties={
+ 'start': composite(Point, vertices.c.x1, vertices.c.y1),
+ 'end': composite(Point, vertices.c.x2, vertices.c.y2)
+ })
+
+Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members
+will flag the attribute as "dirty" on the parent object::
+
+ >>> from sqlalchemy.orm import Session
+
+ >>> sess = Session()
+ >>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15))
+ >>> sess.add(v1)
+ >>> sess.commit()
+
+ >>> v1.end.x = 8
+    >>> v1 in sess.dirty
+ True
+
+Coercing Mutable Composites
+---------------------------
+
+The :meth:`.MutableBase.coerce` method is also supported on composite types.
+In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce`
+method is only called for attribute set operations, not load operations.
+Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent
+to using a :func:`.validates` validation routine for all attributes which
+make use of the custom composite type::
+
+ class Point(MutableComposite):
+ # other Point methods
+ # ...
+
+        @classmethod
+        def coerce(cls, key, value):
+            if isinstance(value, tuple):
+                value = Point(*value)
+            elif not isinstance(value, Point):
+                raise ValueError("tuple or Point expected")
+            return value
+
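+With the above ``coerce`` method in place, attribute set operations may also
+pass a plain tuple, which will be coerced into a ``Point`` before being
+established on the parent::
+
+    v1.end = (8, 10)  # coerced to Point(8, 10)
+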
+Supporting Pickling
+--------------------
+
+As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper
+class uses a ``weakref.WeakKeyDictionary`` available via the
+:attr:`MutableBase._parents` attribute, which isn't picklable. If we need to
+pickle instances of ``Point`` or its owning class ``Vertex``, we at least need
+to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary.
+Below we define both a ``__getstate__`` and a ``__setstate__`` that package up
+the minimal form of our ``Point`` class::
+
+ class Point(MutableComposite):
+ # ...
+
+ def __getstate__(self):
+ return self.x, self.y
+
+ def __setstate__(self, state):
+ self.x, self.y = state
+
+As with :class:`.Mutable`, the :class:`.MutableComposite` augments the
+pickling process of the parent's object-relational state so that the
+:attr:`MutableBase._parents` collection is restored to all ``Point`` objects.
+
+"""
+from collections import defaultdict
+import weakref
+
+from .. import event
+from .. import inspect
+from .. import types
+from ..orm import Mapper
+from ..orm import mapper
+from ..orm.attributes import flag_modified
+from ..sql.base import SchemaEventTarget
+from ..util import memoized_property
+
+
+class MutableBase(object):
+ """Common base class to :class:`.Mutable`
+ and :class:`.MutableComposite`.
+
+ """
+
+ @memoized_property
+ def _parents(self):
+ """Dictionary of parent object's :class:`.InstanceState`->attribute
+ name on the parent.
+
+ This attribute is a so-called "memoized" property. It initializes
+ itself with a new ``weakref.WeakKeyDictionary`` the first time
+ it is accessed, returning the same object upon subsequent access.
+
+ .. versionchanged:: 1.4 the :class:`.InstanceState` is now used
+ as the key in the weak dictionary rather than the instance
+ itself.
+
+ """
+
+ return weakref.WeakKeyDictionary()
+
+ @classmethod
+ def coerce(cls, key, value):
+ """Given a value, coerce it into the target type.
+
+ Can be overridden by custom subclasses to coerce incoming
+ data into a particular type.
+
+ By default, raises ``ValueError``.
+
+ This method is called in different scenarios depending on if
+ the parent class is of type :class:`.Mutable` or of type
+ :class:`.MutableComposite`. In the case of the former, it is called
+ for both attribute-set operations as well as during ORM loading
+ operations. For the latter, it is only called during attribute-set
+ operations; the mechanics of the :func:`.composite` construct
+ handle coercion during load operations.
+
+
+ :param key: string name of the ORM-mapped attribute being set.
+ :param value: the incoming value.
+ :return: the method should return the coerced value, or raise
+ ``ValueError`` if the coercion cannot be completed.
+
+ """
+ if value is None:
+ return None
+ msg = "Attribute '%s' does not accept objects of type %s"
+ raise ValueError(msg % (key, type(value)))
+
+ @classmethod
+ def _get_listen_keys(cls, attribute):
+ """Given a descriptor attribute, return a ``set()`` of the attribute
+ keys which indicate a change in the state of this attribute.
+
+ This is normally just ``set([attribute.key])``, but can be overridden
+ to provide for additional keys. E.g. a :class:`.MutableComposite`
+ augments this set with the attribute keys associated with the columns
+ that comprise the composite value.
+
+ This collection is consulted in the case of intercepting the
+ :meth:`.InstanceEvents.refresh` and
+ :meth:`.InstanceEvents.refresh_flush` events, which pass along a list
+ of attribute names that have been refreshed; the list is compared
+ against this set to determine if action needs to be taken.
+
+ .. versionadded:: 1.0.5
+
+ """
+ return {attribute.key}
+
+ @classmethod
+ def _listen_on_attribute(cls, attribute, coerce, parent_cls):
+ """Establish this type as a mutation listener for the given
+ mapped descriptor.
+
+ """
+ key = attribute.key
+ if parent_cls is not attribute.class_:
+ return
+
+ # rely on "propagate" here
+ parent_cls = attribute.class_
+
+ listen_keys = cls._get_listen_keys(attribute)
+
+ def load(state, *args):
+ """Listen for objects loaded or refreshed.
+
+ Wrap the target data member's value with
+ ``Mutable``.
+
+ """
+ val = state.dict.get(key, None)
+ if val is not None:
+ if coerce:
+ val = cls.coerce(key, val)
+ state.dict[key] = val
+ val._parents[state] = key
+
+ def load_attrs(state, ctx, attrs):
+ if not attrs or listen_keys.intersection(attrs):
+ load(state)
+
+ def set_(target, value, oldvalue, initiator):
+ """Listen for set/replace events on the target
+ data member.
+
+ Establish a weak reference to the parent object
+ on the incoming value, remove it for the one
+ outgoing.
+
+ """
+ if value is oldvalue:
+ return value
+
+ if not isinstance(value, cls):
+ value = cls.coerce(key, value)
+ if value is not None:
+ value._parents[target] = key
+ if isinstance(oldvalue, cls):
+ oldvalue._parents.pop(inspect(target), None)
+ return value
+
+ def pickle(state, state_dict):
+ val = state.dict.get(key, None)
+ if val is not None:
+ if "ext.mutable.values" not in state_dict:
+ state_dict["ext.mutable.values"] = defaultdict(list)
+ state_dict["ext.mutable.values"][key].append(val)
+
+ def unpickle(state, state_dict):
+ if "ext.mutable.values" in state_dict:
+ collection = state_dict["ext.mutable.values"]
+ if isinstance(collection, list):
+ # legacy format
+ for val in collection:
+ val._parents[state] = key
+ else:
+ for val in state_dict["ext.mutable.values"][key]:
+ val._parents[state] = key
+
+ event.listen(parent_cls, "load", load, raw=True, propagate=True)
+ event.listen(
+ parent_cls, "refresh", load_attrs, raw=True, propagate=True
+ )
+ event.listen(
+ parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True
+ )
+ event.listen(
+ attribute, "set", set_, raw=True, retval=True, propagate=True
+ )
+ event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True)
+ event.listen(
+ parent_cls, "unpickle", unpickle, raw=True, propagate=True
+ )
+
+
+class Mutable(MutableBase):
+ """Mixin that defines transparent propagation of change
+ events to a parent object.
+
+ See the example in :ref:`mutable_scalars` for usage information.
+
+ """
+
+ def changed(self):
+ """Subclasses should call this method whenever change events occur."""
+
+ for parent, key in self._parents.items():
+ flag_modified(parent.obj(), key)
+
+ @classmethod
+ def associate_with_attribute(cls, attribute):
+ """Establish this type as a mutation listener for the given
+ mapped descriptor.
+
+ """
+ cls._listen_on_attribute(attribute, True, attribute.class_)
+
+ @classmethod
+ def associate_with(cls, sqltype):
+ """Associate this wrapper with all future mapped columns
+ of the given type.
+
+ This is a convenience method that calls
+ ``associate_with_attribute`` automatically.
+
+ .. warning::
+
+ The listeners established by this method are *global*
+ to all mappers, and are *not* garbage collected. Only use
+ :meth:`.associate_with` for types that are permanent to an
+            application, not with ad-hoc types, else this will cause unbounded
+ growth in memory usage.
+
+ """
+
+ def listen_for_type(mapper, class_):
+ if mapper.non_primary:
+ return
+ for prop in mapper.column_attrs:
+ if isinstance(prop.columns[0].type, sqltype):
+ cls.associate_with_attribute(getattr(class_, prop.key))
+
+ event.listen(mapper, "mapper_configured", listen_for_type)
+
+ @classmethod
+ def as_mutable(cls, sqltype):
+ """Associate a SQL type with this mutable Python type.
+
+ This establishes listeners that will detect ORM mappings against
+ the given type, adding mutation event trackers to those mappings.
+
+ The type is returned, unconditionally as an instance, so that
+ :meth:`.as_mutable` can be used inline::
+
+ Table('mytable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', MyMutableType.as_mutable(PickleType))
+ )
+
+ Note that the returned type is always an instance, even if a class
+ is given, and that only columns which are declared specifically with
+ that type instance receive additional instrumentation.
+
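+        As an illustration of this behavior (a sketch; the table and
+        variable names are hypothetical), only the column declared with the
+        returned instance is instrumented::
+
+            mutable_pickle = MyMutableType.as_mutable(PickleType)
+
+            Table('t1', metadata, Column('data', mutable_pickle))
+
+            # a separate PickleType instance; not instrumented
+            Table('t2', metadata, Column('data', PickleType()))
+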
+ To associate a particular mutable type with all occurrences of a
+ particular type, use the :meth:`.Mutable.associate_with` classmethod
+ of the particular :class:`.Mutable` subclass to establish a global
+ association.
+
+ .. warning::
+
+ The listeners established by this method are *global*
+ to all mappers, and are *not* garbage collected. Only use
+ :meth:`.as_mutable` for types that are permanent to an application,
+            not with ad-hoc types, else this will cause unbounded growth
+ in memory usage.
+
+ """
+ sqltype = types.to_instance(sqltype)
+
+ # a SchemaType will be copied when the Column is copied,
+ # and we'll lose our ability to link that type back to the original.
+ # so track our original type w/ columns
+ if isinstance(sqltype, SchemaEventTarget):
+
+ @event.listens_for(sqltype, "before_parent_attach")
+ def _add_column_memo(sqltyp, parent):
+ parent.info["_ext_mutable_orig_type"] = sqltyp
+
+ schema_event_check = True
+ else:
+ schema_event_check = False
+
+ def listen_for_type(mapper, class_):
+ if mapper.non_primary:
+ return
+ for prop in mapper.column_attrs:
+ if (
+ schema_event_check
+ and hasattr(prop.expression, "info")
+ and prop.expression.info.get("_ext_mutable_orig_type")
+ is sqltype
+ ) or (prop.columns[0].type is sqltype):
+ cls.associate_with_attribute(getattr(class_, prop.key))
+
+ event.listen(mapper, "mapper_configured", listen_for_type)
+
+ return sqltype
+
+
+class MutableComposite(MutableBase):
+ """Mixin that defines transparent propagation of change
+ events on a SQLAlchemy "composite" object to its
+ owning parent or parents.
+
+ See the example in :ref:`mutable_composites` for usage information.
+
+ """
+
+ @classmethod
+ def _get_listen_keys(cls, attribute):
+ return {attribute.key}.union(attribute.property._attribute_keys)
+
+ def changed(self):
+ """Subclasses should call this method whenever change events occur."""
+
+ for parent, key in self._parents.items():
+
+ prop = parent.mapper.get_property(key)
+ for value, attr_name in zip(
+ self.__composite_values__(), prop._attribute_keys
+ ):
+ setattr(parent.obj(), attr_name, value)
+
+
+def _setup_composite_listener():
+ def _listen_for_type(mapper, class_):
+ for prop in mapper.iterate_properties:
+ if (
+ hasattr(prop, "composite_class")
+ and isinstance(prop.composite_class, type)
+ and issubclass(prop.composite_class, MutableComposite)
+ ):
+ prop.composite_class._listen_on_attribute(
+ getattr(class_, prop.key), False, class_
+ )
+
+ if not event.contains(Mapper, "mapper_configured", _listen_for_type):
+ event.listen(Mapper, "mapper_configured", _listen_for_type)
+
+
+_setup_composite_listener()
+
+
+class MutableDict(Mutable, dict):
+ """A dictionary type that implements :class:`.Mutable`.
+
+ The :class:`.MutableDict` object implements a dictionary that will
+ emit change events to the underlying mapping when the contents of
+ the dictionary are altered, including when values are added or removed.
+
+ Note that :class:`.MutableDict` does **not** apply mutable tracking to the
+ *values themselves* inside the dictionary. Therefore it is not a sufficient
+ solution for the use case of tracking deep changes to a *recursive*
+ dictionary structure, such as a JSON structure. To support this use case,
+ build a subclass of :class:`.MutableDict` that provides appropriate
+ coercion to the values placed in the dictionary so that they too are
+ "mutable", and emit events up to their parent structure.
+
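+    As an illustration only, a minimal sketch of such a subclass (the names
+    ``NestedMutableDict`` and ``_nested_parent`` are hypothetical, and
+    pickling support is omitted) might look like::
+
+        class NestedMutableDict(MutableDict):
+            @classmethod
+            def coerce(cls, key, value):
+                if isinstance(value, dict) and not isinstance(value, cls):
+                    result = cls()
+                    for k, v in value.items():
+                        result[k] = v  # routes through __setitem__ below
+                    return result
+                return super(NestedMutableDict, cls).coerce(key, value)
+
+            def __setitem__(self, key, value):
+                if isinstance(value, dict):
+                    value = self.coerce(key, value)
+                    value._nested_parent = self
+                super(NestedMutableDict, self).__setitem__(key, value)
+
+            def changed(self):
+                # flag mapped parents, then bubble the event up through
+                # any enclosing NestedMutableDict
+                super(NestedMutableDict, self).changed()
+                parent = getattr(self, '_nested_parent', None)
+                if parent is not None:
+                    parent.changed()
+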
+ .. seealso::
+
+ :class:`.MutableList`
+
+ :class:`.MutableSet`
+
+ """
+
+ def __setitem__(self, key, value):
+ """Detect dictionary set events and emit change events."""
+ dict.__setitem__(self, key, value)
+ self.changed()
+
+ def setdefault(self, key, value):
+ result = dict.setdefault(self, key, value)
+ self.changed()
+ return result
+
+ def __delitem__(self, key):
+ """Detect dictionary del events and emit change events."""
+ dict.__delitem__(self, key)
+ self.changed()
+
+ def update(self, *a, **kw):
+ dict.update(self, *a, **kw)
+ self.changed()
+
+ def pop(self, *arg):
+ result = dict.pop(self, *arg)
+ self.changed()
+ return result
+
+ def popitem(self):
+ result = dict.popitem(self)
+ self.changed()
+ return result
+
+ def clear(self):
+ dict.clear(self)
+ self.changed()
+
+ @classmethod
+ def coerce(cls, key, value):
+ """Convert plain dictionary to instance of this class."""
+ if not isinstance(value, cls):
+ if isinstance(value, dict):
+ return cls(value)
+ return Mutable.coerce(key, value)
+ else:
+ return value
+
+ def __getstate__(self):
+ return dict(self)
+
+ def __setstate__(self, state):
+ self.update(state)
+
+
+class MutableList(Mutable, list):
+ """A list type that implements :class:`.Mutable`.
+
+ The :class:`.MutableList` object implements a list that will
+ emit change events to the underlying mapping when the contents of
+ the list are altered, including when values are added or removed.
+
+ Note that :class:`.MutableList` does **not** apply mutable tracking to the
+ *values themselves* inside the list. Therefore it is not a sufficient
+ solution for the use case of tracking deep changes to a *recursive*
+ mutable structure, such as a JSON structure. To support this use case,
+ build a subclass of :class:`.MutableList` that provides appropriate
+    coercion to the values placed in the list so that they too are
+ "mutable", and emit events up to their parent structure.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :class:`.MutableDict`
+
+ :class:`.MutableSet`
+
+ """
+
+ def __reduce_ex__(self, proto):
+ return (self.__class__, (list(self),))
+
+ # needed for backwards compatibility with
+ # older pickles
+ def __setstate__(self, state):
+ self[:] = state
+
+ def __setitem__(self, index, value):
+ """Detect list set events and emit change events."""
+ list.__setitem__(self, index, value)
+ self.changed()
+
+ def __setslice__(self, start, end, value):
+ """Detect list set events and emit change events."""
+ list.__setslice__(self, start, end, value)
+ self.changed()
+
+ def __delitem__(self, index):
+ """Detect list del events and emit change events."""
+ list.__delitem__(self, index)
+ self.changed()
+
+ def __delslice__(self, start, end):
+ """Detect list del events and emit change events."""
+ list.__delslice__(self, start, end)
+ self.changed()
+
+ def pop(self, *arg):
+ result = list.pop(self, *arg)
+ self.changed()
+ return result
+
+ def append(self, x):
+ list.append(self, x)
+ self.changed()
+
+ def extend(self, x):
+ list.extend(self, x)
+ self.changed()
+
+ def __iadd__(self, x):
+ self.extend(x)
+ return self
+
+ def insert(self, i, x):
+ list.insert(self, i, x)
+ self.changed()
+
+ def remove(self, i):
+ list.remove(self, i)
+ self.changed()
+
+ def clear(self):
+ list.clear(self)
+ self.changed()
+
+ def sort(self, **kw):
+ list.sort(self, **kw)
+ self.changed()
+
+ def reverse(self):
+ list.reverse(self)
+ self.changed()
+
+ @classmethod
+ def coerce(cls, index, value):
+ """Convert plain list to instance of this class."""
+ if not isinstance(value, cls):
+ if isinstance(value, list):
+ return cls(value)
+ return Mutable.coerce(index, value)
+ else:
+ return value
+
+
+class MutableSet(Mutable, set):
+ """A set type that implements :class:`.Mutable`.
+
+ The :class:`.MutableSet` object implements a set that will
+ emit change events to the underlying mapping when the contents of
+ the set are altered, including when values are added or removed.
+
+ Note that :class:`.MutableSet` does **not** apply mutable tracking to the
+ *values themselves* inside the set. Therefore it is not a sufficient
+ solution for the use case of tracking deep changes to a *recursive*
+ mutable structure. To support this use case,
+ build a subclass of :class:`.MutableSet` that provides appropriate
+    coercion to the values placed in the set so that they too are
+ "mutable", and emit events up to their parent structure.
+
+ .. versionadded:: 1.1
+
+ .. seealso::
+
+ :class:`.MutableDict`
+
+ :class:`.MutableList`
+
+
+ """
+
+ def update(self, *arg):
+ set.update(self, *arg)
+ self.changed()
+
+ def intersection_update(self, *arg):
+ set.intersection_update(self, *arg)
+ self.changed()
+
+ def difference_update(self, *arg):
+ set.difference_update(self, *arg)
+ self.changed()
+
+ def symmetric_difference_update(self, *arg):
+ set.symmetric_difference_update(self, *arg)
+ self.changed()
+
+ def __ior__(self, other):
+ self.update(other)
+ return self
+
+ def __iand__(self, other):
+ self.intersection_update(other)
+ return self
+
+ def __ixor__(self, other):
+ self.symmetric_difference_update(other)
+ return self
+
+ def __isub__(self, other):
+ self.difference_update(other)
+ return self
+
+ def add(self, elem):
+ set.add(self, elem)
+ self.changed()
+
+ def remove(self, elem):
+ set.remove(self, elem)
+ self.changed()
+
+ def discard(self, elem):
+ set.discard(self, elem)
+ self.changed()
+
+ def pop(self, *arg):
+ result = set.pop(self, *arg)
+ self.changed()
+ return result
+
+ def clear(self):
+ set.clear(self)
+ self.changed()
+
+ @classmethod
+ def coerce(cls, index, value):
+ """Convert plain set to instance of this class."""
+ if not isinstance(value, cls):
+ if isinstance(value, set):
+ return cls(value)
+ return Mutable.coerce(index, value)
+ else:
+ return value
+
+ def __getstate__(self):
+ return set(self)
+
+ def __setstate__(self, state):
+ self.update(state)
+
+ def __reduce_ex__(self, proto):
+ return (self.__class__, (list(self),))
diff --git a/lib/sqlalchemy/ext/mypy/__init__.py b/lib/sqlalchemy/ext/mypy/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/__init__.py
diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py
new file mode 100644
index 0000000..99be194
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/apply.py
@@ -0,0 +1,299 @@
+# ext/mypy/apply.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mypy.nodes import ARG_NAMED_OPT
+from mypy.nodes import Argument
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import MDEF
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TempNode
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.plugins.common import add_method_to_class
+from mypy.types import AnyType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import NoneTyp
+from mypy.types import ProperType
+from mypy.types import TypeOfAny
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+from . import infer
+from . import util
+from .names import NAMED_TYPE_SQLA_MAPPED
+
+
+def apply_mypy_mapped_attr(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ item: Union[NameExpr, StrExpr],
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
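+    """Apply a Mapped[] annotation to an attribute that was named in a
+    class-level ``_mypy_mapped_attrs`` list.
+
+    """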
+ if isinstance(item, NameExpr):
+ name = item.name
+ elif isinstance(item, StrExpr):
+ name = item.value
+ else:
+ return None
+
+ for stmt in cls.defs.body:
+ if (
+ isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
+ and stmt.lvalues[0].name == name
+ ):
+ break
+ else:
+ util.fail(api, "Can't find mapped attribute {}".format(name), cls)
+ return None
+
+ if stmt.type is None:
+ util.fail(
+ api,
+ "Statement linked from _mypy_mapped_attrs has no "
+ "typing information",
+ stmt,
+ )
+ return None
+
+ left_hand_explicit_type = get_proper_type(stmt.type)
+ assert isinstance(
+ left_hand_explicit_type, (Instance, UnionType, UnboundType)
+ )
+
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=name,
+ line=item.line,
+ column=item.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
+
+ apply_type_to_mapped_statement(
+ api, stmt, stmt.lvalues[0], left_hand_explicit_type, None
+ )
+
+
+def re_apply_declarative_assignments(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """For multiple class passes, re-apply our left-hand side types as mypy
+ seems to reset them in place.
+
+ """
+ mapped_attr_lookup = {attr.name: attr for attr in attributes}
+ update_cls_metadata = False
+
+ for stmt in cls.defs.body:
+ # for a re-apply, all of our statements are AssignmentStmt;
+ # @declared_attr calls will have been converted and this
+ # currently seems to be preserved by mypy (but who knows if this
+ # will change).
+ if (
+ isinstance(stmt, AssignmentStmt)
+ and isinstance(stmt.lvalues[0], NameExpr)
+ and stmt.lvalues[0].name in mapped_attr_lookup
+ and isinstance(stmt.lvalues[0].node, Var)
+ ):
+
+ left_node = stmt.lvalues[0].node
+ python_type_for_type = mapped_attr_lookup[
+ stmt.lvalues[0].name
+ ].type
+
+ left_node_proper_type = get_proper_type(left_node.type)
+
+ # if we have scanned an UnboundType and now there's a more
+ # specific type than UnboundType, call the re-scan so we
+ # can get that set up correctly
+ if (
+ isinstance(python_type_for_type, UnboundType)
+ and not isinstance(left_node_proper_type, UnboundType)
+ and (
+ isinstance(stmt.rvalue, CallExpr)
+ and isinstance(stmt.rvalue.callee, MemberExpr)
+ and isinstance(stmt.rvalue.callee.expr, NameExpr)
+ and stmt.rvalue.callee.expr.node is not None
+ and stmt.rvalue.callee.expr.node.fullname
+ == NAMED_TYPE_SQLA_MAPPED
+ and stmt.rvalue.callee.name == "_empty_constructor"
+ and isinstance(stmt.rvalue.args[0], CallExpr)
+ and isinstance(stmt.rvalue.args[0].callee, RefExpr)
+ )
+ ):
+
+ python_type_for_type = (
+ infer.infer_type_from_right_hand_nameexpr(
+ api,
+ stmt,
+ left_node,
+ left_node_proper_type,
+ stmt.rvalue.args[0].callee,
+ )
+ )
+
+ if python_type_for_type is None or isinstance(
+ python_type_for_type, UnboundType
+ ):
+ continue
+
+ # update the SQLAlchemyAttribute with the better information
+ mapped_attr_lookup[
+ stmt.lvalues[0].name
+ ].type = python_type_for_type
+
+ update_cls_metadata = True
+
+ if python_type_for_type is not None:
+ left_node.type = api.named_type(
+ NAMED_TYPE_SQLA_MAPPED, [python_type_for_type]
+ )
+
+ if update_cls_metadata:
+ util.set_mapped_attributes(cls.info, attributes)
+
+
+def apply_type_to_mapped_statement(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ lvalue: NameExpr,
+ left_hand_explicit_type: Optional[ProperType],
+ python_type_for_type: Optional[ProperType],
+) -> None:
+ """Apply the Mapped[<type>] annotation and right hand object to a
+ declarative assignment statement.
+
+ This converts a Python declarative class statement such as::
+
+ class User(Base):
+ # ...
+
+ attrname = Column(Integer)
+
+ To one that describes the final Python behavior to Mypy::
+
+ class User(Base):
+ # ...
+
+ attrname : Mapped[Optional[int]] = <meaningless temp node>
+
+ """
+ left_node = lvalue.node
+ assert isinstance(left_node, Var)
+
+ if left_hand_explicit_type is not None:
+ left_node.type = api.named_type(
+ NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
+ )
+ else:
+ lvalue.is_inferred_def = False
+ left_node.type = api.named_type(
+ NAMED_TYPE_SQLA_MAPPED,
+ [] if python_type_for_type is None else [python_type_for_type],
+ )
+
+ # so to have it skip the right side totally, we can do this:
+ # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
+
+ # however, if we instead manufacture a new node that uses the old
+ # one, then we can still get type checking for the call itself,
+ # e.g. the Column, relationship() call, etc.
+
+ # rewrite the node as:
+ # <attr> : Mapped[<typ>] =
+ # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
+ # the original right-hand side is maintained so it gets type checked
+ # internally
+ stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue)
+
+
+def add_additional_orm_attributes(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Apply __init__, __table__ and other attributes to the mapped class."""
+
+ info = util.info_for_cls(cls, api)
+
+ if info is None:
+ return
+
+ is_base = util.get_is_base(info)
+
+ if "__init__" not in info.names and not is_base:
+ mapped_attr_names = {attr.name: attr.type for attr in attributes}
+
+ for base in info.mro[1:-1]:
+ if "sqlalchemy" not in info.metadata:
+ continue
+
+ base_cls_attributes = util.get_mapped_attributes(base, api)
+ if base_cls_attributes is None:
+ continue
+
+ for attr in base_cls_attributes:
+ mapped_attr_names.setdefault(attr.name, attr.type)
+
+ arguments = []
+ for name, typ in mapped_attr_names.items():
+ if typ is None:
+ typ = AnyType(TypeOfAny.special_form)
+ arguments.append(
+ Argument(
+ variable=Var(name, typ),
+ type_annotation=typ,
+ initializer=TempNode(typ),
+ kind=ARG_NAMED_OPT,
+ )
+ )
+
+ add_method_to_class(api, cls, "__init__", arguments, NoneTyp())
+
+ if "__table__" not in info.names and util.get_has_table(info):
+ _apply_placeholder_attr_to_class(
+ api, cls, "sqlalchemy.sql.schema.Table", "__table__"
+ )
+ if not is_base:
+ _apply_placeholder_attr_to_class(
+ api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
+ )
+
+
+def _apply_placeholder_attr_to_class(
+ api: SemanticAnalyzerPluginInterface,
+ cls: ClassDef,
+ qualified_name: str,
+ attrname: str,
+) -> None:
+ sym = api.lookup_fully_qualified_or_none(qualified_name)
+ if sym:
+ assert isinstance(sym.node, TypeInfo)
+ type_: ProperType = Instance(sym.node, [])
+ else:
+ type_ = AnyType(TypeOfAny.special_form)
+ var = Var(attrname)
+ var._fullname = cls.fullname + "." + attrname
+ var.info = cls.info
+ var.type = type_
+ cls.info.names[attrname] = SymbolTableNode(MDEF, var)
diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py
new file mode 100644
index 0000000..c33c30e
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/decl_class.py
@@ -0,0 +1,516 @@
+# ext/mypy/decl_class.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import List
+from typing import Optional
+from typing import Union
+
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import Decorator
+from mypy.nodes import LambdaExpr
+from mypy.nodes import ListExpr
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import PlaceholderNode
+from mypy.nodes import RefExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import SymbolNode
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TempNode
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import AnyType
+from mypy.types import CallableType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import Type
+from mypy.types import TypeOfAny
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+from . import apply
+from . import infer
+from . import names
+from . import util
+
+
+def scan_declarative_assignments_and_apply_types(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ is_mixin_scan: bool = False,
+) -> Optional[List[util.SQLAlchemyAttribute]]:
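+    """Scan the body of a declarative class for ORM-mapped attributes,
+    applying Mapped[] types to their assignment statements, and return
+    the collected attributes.
+
+    """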
+
+ info = util.info_for_cls(cls, api)
+
+ if info is None:
+ # this can occur during cached passes
+ return None
+ elif cls.fullname.startswith("builtins"):
+ return None
+
+ mapped_attributes: Optional[
+ List[util.SQLAlchemyAttribute]
+ ] = util.get_mapped_attributes(info, api)
+
+    # used by apply.add_additional_orm_attributes among others
+ util.establish_as_sqlalchemy(info)
+
+ if mapped_attributes is not None:
+ # ensure that a class that's mapped is always picked up by
+ # its mapped() decorator or declarative metaclass before
+ # it would be detected as an unmapped mixin class
+
+ if not is_mixin_scan:
+            # mypy can call us more than once.  It then *may* have reset
+            # the left hand side of everything, but not the right hand
+            # side that we replaced, removing our ability to re-scan.  But
+            # we have the types here, so let's re-apply them, or if we
+            # have an UnboundType, we can re-scan.
+
+ apply.re_apply_declarative_assignments(cls, api, mapped_attributes)
+
+ return mapped_attributes
+
+ mapped_attributes = []
+
+ if not cls.defs.body:
+ # when we get a mixin class from another file, the body is
+ # empty (!) but the names are in the symbol table. so use that.
+
+ for sym_name, sym in info.names.items():
+ _scan_symbol_table_entry(
+ cls, api, sym_name, sym, mapped_attributes
+ )
+ else:
+ for stmt in util.flatten_typechecking(cls.defs.body):
+ if isinstance(stmt, AssignmentStmt):
+ _scan_declarative_assignment_stmt(
+ cls, api, stmt, mapped_attributes
+ )
+ elif isinstance(stmt, Decorator):
+ _scan_declarative_decorator_stmt(
+ cls, api, stmt, mapped_attributes
+ )
+ _scan_for_mapped_bases(cls, api)
+
+ if not is_mixin_scan:
+ apply.add_additional_orm_attributes(cls, api, mapped_attributes)
+
+ util.set_mapped_attributes(info, mapped_attributes)
+
+ return mapped_attributes
+
+
+def _scan_symbol_table_entry(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ name: str,
+ value: SymbolTableNode,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Extract mapping information from a SymbolTableNode that's in the
+ type.names dictionary.
+
+ """
+ value_type = get_proper_type(value.type)
+ if not isinstance(value_type, Instance):
+ return
+
+ left_hand_explicit_type = None
+ type_id = names.type_id_for_named_node(value_type.type)
+ # type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
+
+ err = False
+
+ # TODO: this is nearly the same logic as that of
+ # _scan_declarative_decorator_stmt, likely can be merged
+ if type_id in {
+ names.MAPPED,
+ names.RELATIONSHIP,
+ names.COMPOSITE_PROPERTY,
+ names.MAPPER_PROPERTY,
+ names.SYNONYM_PROPERTY,
+ names.COLUMN_PROPERTY,
+ }:
+ if value_type.args:
+ left_hand_explicit_type = get_proper_type(value_type.args[0])
+ else:
+ err = True
+ elif type_id is names.COLUMN:
+ if not value_type.args:
+ err = True
+ else:
+ typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
+ value_type.args[0]
+ )
+ if isinstance(typeengine_arg, Instance):
+ typeengine_arg = typeengine_arg.type
+
+ if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
+ sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
+ if sym is not None and isinstance(sym.node, TypeInfo):
+ if names.has_base_type_id(sym.node, names.TYPEENGINE):
+
+ left_hand_explicit_type = UnionType(
+ [
+ infer.extract_python_type_from_typeengine(
+ api, sym.node, []
+ ),
+ NoneType(),
+ ]
+ )
+ else:
+ util.fail(
+ api,
+ "Column type should be a TypeEngine "
+ "subclass not '{}'".format(sym.node.fullname),
+ value_type,
+ )
+
+ if err:
+ msg = (
+ "Can't infer type from attribute {} on class {}. "
+ "please specify a return type from this function that is "
+ "one of: Mapped[<python type>], relationship[<target class>], "
+ "Column[<TypeEngine>], MapperProperty[<python type>]"
+ )
+ util.fail(api, msg.format(name, cls.name), cls)
+
+ left_hand_explicit_type = AnyType(TypeOfAny.special_form)
+
+ if left_hand_explicit_type is not None:
+ assert value.node is not None
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=name,
+ line=value.node.line,
+ column=value.node.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
+
+
+def _scan_declarative_decorator_stmt(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ stmt: Decorator,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Extract mapping information from a @declared_attr in a declarative
+ class.
+
+ E.g.::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ @declared_attr
+ def updated_at(cls) -> Column[DateTime]:
+ return Column(DateTime)
+
+ Will resolve in mypy as::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ updated_at: Mapped[Optional[datetime.datetime]]
+
+ """
+ for dec in stmt.decorators:
+ if (
+ isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
+ and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
+ ):
+ break
+ else:
+ return
+
+ dec_index = cls.defs.body.index(stmt)
+
+ left_hand_explicit_type: Optional[ProperType] = None
+
+ if util.name_is_dunder(stmt.name):
+ # for dunder names like __table_args__, __tablename__,
+ # __mapper_args__ etc., rewrite these as simple assignment
+ # statements; otherwise mypy doesn't like if the decorated
+ # function has an annotation like ``cls: Type[Foo]`` because
+ # it isn't @classmethod
+ any_ = AnyType(TypeOfAny.special_form)
+ left_node = NameExpr(stmt.var.name)
+ left_node.node = stmt.var
+ new_stmt = AssignmentStmt([left_node], TempNode(any_))
+ new_stmt.type = left_node.node.type
+ cls.defs.body[dec_index] = new_stmt
+ return
+ elif isinstance(stmt.func.type, CallableType):
+ func_type = stmt.func.type.ret_type
+ if isinstance(func_type, UnboundType):
+ type_id = names.type_id_for_unbound_type(func_type, cls, api)
+ else:
+ # this does not seem to occur unless the type argument is
+ # incorrect
+ return
+
+        if (
+            type_id
+            in {
+                names.MAPPED,
+                names.RELATIONSHIP,
+                names.COMPOSITE_PROPERTY,
+                names.MAPPER_PROPERTY,
+                names.SYNONYM_PROPERTY,
+                names.COLUMN_PROPERTY,
+            }
+            and func_type.args
+        ):
+            left_hand_explicit_type = get_proper_type(func_type.args[0])
+        elif type_id is names.COLUMN and func_type.args:
+            typeengine_arg = func_type.args[0]
+            if isinstance(typeengine_arg, UnboundType):
+                sym = api.lookup_qualified(
+                    typeengine_arg.name, typeengine_arg
+                )
+                if sym is not None and isinstance(sym.node, TypeInfo):
+                    if names.has_base_type_id(sym.node, names.TYPEENGINE):
+                        left_hand_explicit_type = UnionType(
+                            [
+                                infer.extract_python_type_from_typeengine(
+                                    api, sym.node, []
+                                ),
+                                NoneType(),
+                            ]
+                        )
+                    else:
+                        util.fail(
+                            api,
+                            "Column type should be a TypeEngine "
+                            "subclass not '{}'".format(sym.node.fullname),
+                            func_type,
+                        )
+
+ if left_hand_explicit_type is None:
+ # no type on the decorated function. our option here is to
+ # dig into the function body and get the return type, but they
+ # should just have an annotation.
+ msg = (
+ "Can't infer type from @declared_attr on function '{}'; "
+ "please specify a return type from this function that is "
+ "one of: Mapped[<python type>], relationship[<target class>], "
+ "Column[<TypeEngine>], MapperProperty[<python type>]"
+ )
+ util.fail(api, msg.format(stmt.var.name), stmt)
+
+ left_hand_explicit_type = AnyType(TypeOfAny.special_form)
+
+ left_node = NameExpr(stmt.var.name)
+ left_node.node = stmt.var
+
+    # feeling around in the dark here, as the significance of UnboundType
+    # isn't fully understood.  It seems to be something that is not going
+    # to do what's expected when it is applied as the type of an
+    # AssignmentStmt.  So convert it to the regular
+    # Instance/TypeInfo/UnionType structures we see everywhere else.
+ if isinstance(left_hand_explicit_type, UnboundType):
+ left_hand_explicit_type = get_proper_type(
+ util.unbound_to_instance(api, left_hand_explicit_type)
+ )
+
+ left_node.node.type = api.named_type(
+ names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
+ )
+
+ # this will ignore the rvalue entirely
+ # rvalue = TempNode(AnyType(TypeOfAny.special_form))
+
+ # rewrite the node as:
+ # <attr> : Mapped[<typ>] =
+ # _sa_Mapped._empty_constructor(lambda: <function body>)
+ # the function body is maintained so it gets type checked internally
+ rvalue = util.expr_to_mapped_constructor(
+ LambdaExpr(stmt.func.arguments, stmt.func.body)
+ )
+
+ new_stmt = AssignmentStmt([left_node], rvalue)
+ new_stmt.type = left_node.node.type
+
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=left_node.name,
+ line=stmt.line,
+ column=stmt.column,
+ typ=left_hand_explicit_type,
+ info=cls.info,
+ )
+ )
+ cls.defs.body[dec_index] = new_stmt
+
+
+def _scan_declarative_assignment_stmt(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ attributes: List[util.SQLAlchemyAttribute],
+) -> None:
+ """Extract mapping information from an assignment statement in a
+ declarative class.
+
+ """
+ lvalue = stmt.lvalues[0]
+ if not isinstance(lvalue, NameExpr):
+ return
+
+ sym = cls.info.names.get(lvalue.name)
+
+ # this establishes that semantic analysis has taken place, which
+ # means the nodes are populated and we are called from an appropriate
+ # hook.
+ assert sym is not None
+ node = sym.node
+
+ if isinstance(node, PlaceholderNode):
+ return
+
+ assert node is lvalue.node
+ assert isinstance(node, Var)
+
+ if node.name == "__abstract__":
+ if api.parse_bool(stmt.rvalue) is True:
+ util.set_is_base(cls.info)
+ return
+ elif node.name == "__tablename__":
+ util.set_has_table(cls.info)
+ elif node.name.startswith("__"):
+ return
+ elif node.name == "_mypy_mapped_attrs":
+ if not isinstance(stmt.rvalue, ListExpr):
+ util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt)
+ else:
+ for item in stmt.rvalue.items:
+ if isinstance(item, (NameExpr, StrExpr)):
+ apply.apply_mypy_mapped_attr(cls, api, item, attributes)
+
+ left_hand_mapped_type: Optional[Type] = None
+ left_hand_explicit_type: Optional[ProperType] = None
+
+ if node.is_inferred or node.type is None:
+ if isinstance(stmt.type, UnboundType):
+ # look for an explicit Mapped[] type annotation on the left
+ # side with nothing on the right
+
+ # print(stmt.type)
+ # Mapped?[Optional?[A?]]
+
+ left_hand_explicit_type = stmt.type
+
+ if stmt.type.name == "Mapped":
+ mapped_sym = api.lookup_qualified("Mapped", cls)
+ if (
+ mapped_sym is not None
+ and mapped_sym.node is not None
+ and names.type_id_for_named_node(mapped_sym.node)
+ is names.MAPPED
+ ):
+ left_hand_explicit_type = get_proper_type(
+ stmt.type.args[0]
+ )
+ left_hand_mapped_type = stmt.type
+
+ # TODO: do we need to convert from unbound for this case?
+ # left_hand_explicit_type = util._unbound_to_instance(
+ # api, left_hand_explicit_type
+ # )
+ else:
+ node_type = get_proper_type(node.type)
+ if (
+ isinstance(node_type, Instance)
+ and names.type_id_for_named_node(node_type.type) is names.MAPPED
+ ):
+ # print(node.type)
+ # sqlalchemy.orm.attributes.Mapped[<python type>]
+ left_hand_explicit_type = get_proper_type(node_type.args[0])
+ left_hand_mapped_type = node_type
+ else:
+ # print(node.type)
+ # <python type>
+ left_hand_explicit_type = node_type
+ left_hand_mapped_type = None
+
+ if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
+ # annotation without assignment and Mapped is present
+ # as type annotation
+ # equivalent to using _infer_type_from_left_hand_type_only.
+
+ python_type_for_type = left_hand_explicit_type
+ elif isinstance(stmt.rvalue, CallExpr) and isinstance(
+ stmt.rvalue.callee, RefExpr
+ ):
+
+ python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
+ api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
+ )
+
+ if python_type_for_type is None:
+ return
+
+ else:
+ return
+
+ assert python_type_for_type is not None
+
+ attributes.append(
+ util.SQLAlchemyAttribute(
+ name=node.name,
+ line=stmt.line,
+ column=stmt.column,
+ typ=python_type_for_type,
+ info=cls.info,
+ )
+ )
+
+ apply.apply_type_to_mapped_statement(
+ api,
+ stmt,
+ lvalue,
+ left_hand_explicit_type,
+ python_type_for_type,
+ )
+
+
+def _scan_for_mapped_bases(
+ cls: ClassDef,
+ api: SemanticAnalyzerPluginInterface,
+) -> None:
+ """Given a class, iterate through its superclass hierarchy to find
+ all other classes that are considered as ORM-significant.
+
+ Locates non-mapped mixins and scans them for mapped attributes to be
+ applied to subclasses.
+
+ """
+
+ info = util.info_for_cls(cls, api)
+
+ if info is None:
+ return
+
+ for base_info in info.mro[1:-1]:
+ if base_info.fullname.startswith("builtins"):
+ continue
+
+ # scan each base for mapped attributes. if they are not already
+ # scanned (but have all their type info), that means they are unmapped
+ # mixins
+ scan_declarative_assignments_and_apply_types(
+ base_info.defn, api, is_mixin_scan=True
+ )
diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py
new file mode 100644
index 0000000..f88a960
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/infer.py
@@ -0,0 +1,556 @@
+# ext/mypy/infer.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import Optional
+from typing import Sequence
+
+from mypy.maptype import map_instance_to_supertype
+from mypy.messages import format_type
+from mypy.nodes import AssignmentStmt
+from mypy.nodes import CallExpr
+from mypy.nodes import Expression
+from mypy.nodes import FuncDef
+from mypy.nodes import LambdaExpr
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import RefExpr
+from mypy.nodes import StrExpr
+from mypy.nodes import TypeInfo
+from mypy.nodes import Var
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.subtypes import is_subtype
+from mypy.types import AnyType
+from mypy.types import CallableType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import ProperType
+from mypy.types import TypeOfAny
+from mypy.types import UnionType
+
+from . import names
+from . import util
+
+
+def infer_type_from_right_hand_nameexpr(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+ infer_from_right_side: RefExpr,
+) -> Optional[ProperType]:
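+    """Dispatch to a more specific type inference function, based on the
+    kind of ORM construct (Column, relationship(), column_property(),
+    etc.) found on the right hand side of a declarative assignment.
+
+    """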
+
+ type_id = names.type_id_for_callee(infer_from_right_side)
+
+ if type_id is None:
+ return None
+ elif type_id is names.COLUMN:
+ python_type_for_type = _infer_type_from_decl_column(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.RELATIONSHIP:
+ python_type_for_type = _infer_type_from_relationship(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.COLUMN_PROPERTY:
+ python_type_for_type = _infer_type_from_decl_column_property(
+ api, stmt, node, left_hand_explicit_type
+ )
+ elif type_id is names.SYNONYM_PROPERTY:
+ python_type_for_type = infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif type_id is names.COMPOSITE_PROPERTY:
+ python_type_for_type = _infer_type_from_decl_composite_property(
+ api, stmt, node, left_hand_explicit_type
+ )
+ else:
+ return None
+
+ return python_type_for_type
+
+
+def _infer_type_from_relationship(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a relationship.
+
+ E.g.::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ addresses = relationship(Address, uselist=True)
+
+ order: Mapped["Order"] = relationship("Order")
+
+ Will resolve in mypy as::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ addresses: Mapped[List[Address]]
+
+ order: Mapped["Order"]
+
+ """
+
+ assert isinstance(stmt.rvalue, CallExpr)
+ target_cls_arg = stmt.rvalue.args[0]
+ python_type_for_type: Optional[ProperType] = None
+
+ if isinstance(target_cls_arg, NameExpr) and isinstance(
+ target_cls_arg.node, TypeInfo
+ ):
+ # type
+ related_object_type = target_cls_arg.node
+ python_type_for_type = Instance(related_object_type, [])
+
+ # other cases not covered - an error message directs the user
+ # to set an explicit type annotation
+ #
+ # node.type == str, it's a string
+ # if isinstance(target_cls_arg, NameExpr) and isinstance(
+ # target_cls_arg.node, Var
+ # )
+ # points to a type
+ # isinstance(target_cls_arg, NameExpr) and isinstance(
+ # target_cls_arg.node, TypeAlias
+ # )
+ # string expression
+ # isinstance(target_cls_arg, StrExpr)
+
+ uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist")
+ collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg(
+ stmt.rvalue, "collection_class"
+ )
+ type_is_a_collection = False
+
+ # this can be used to determine Optional for a many-to-one
+ # in the same way nullable=False could be used, if we start supporting
+ # that.
+ # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin")
+
+ if (
+ uselist_arg is not None
+ and api.parse_bool(uselist_arg) is True
+ and collection_cls_arg is None
+ ):
+ type_is_a_collection = True
+ if python_type_for_type is not None:
+ python_type_for_type = api.named_type(
+ names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type]
+ )
+ elif (
+ uselist_arg is None or api.parse_bool(uselist_arg) is True
+ ) and collection_cls_arg is not None:
+ type_is_a_collection = True
+ if isinstance(collection_cls_arg, CallExpr):
+ collection_cls_arg = collection_cls_arg.callee
+
+ if isinstance(collection_cls_arg, NameExpr) and isinstance(
+ collection_cls_arg.node, TypeInfo
+ ):
+ if python_type_for_type is not None:
+ # this can still be overridden by the left hand side
+                # within _infer_type_from_left_and_inferred_right
+ python_type_for_type = Instance(
+ collection_cls_arg.node, [python_type_for_type]
+ )
+ elif (
+ isinstance(collection_cls_arg, NameExpr)
+ and isinstance(collection_cls_arg.node, FuncDef)
+ and collection_cls_arg.node.type is not None
+ ):
+ if python_type_for_type is not None:
+ # this can still be overridden by the left hand side
+                # within _infer_type_from_left_and_inferred_right
+
+ # TODO: handle mypy.types.Overloaded
+ if isinstance(collection_cls_arg.node.type, CallableType):
+ rt = get_proper_type(collection_cls_arg.node.type.ret_type)
+
+ if isinstance(rt, CallableType):
+ callable_ret_type = get_proper_type(rt.ret_type)
+ if isinstance(callable_ret_type, Instance):
+ python_type_for_type = Instance(
+ callable_ret_type.type,
+ [python_type_for_type],
+ )
+ else:
+ util.fail(
+ api,
+ "Expected Python collection type for "
+ "collection_class parameter",
+ stmt.rvalue,
+ )
+ python_type_for_type = None
+ elif uselist_arg is not None and api.parse_bool(uselist_arg) is False:
+ if collection_cls_arg is not None:
+ util.fail(
+ api,
+ "Sending uselist=False and collection_class at the same time "
+ "does not make sense",
+ stmt.rvalue,
+ )
+ if python_type_for_type is not None:
+ python_type_for_type = UnionType(
+ [python_type_for_type, NoneType()]
+ )
+
+ else:
+ if left_hand_explicit_type is None:
+ msg = (
+ "Can't infer scalar or collection for ORM mapped expression "
+ "assigned to attribute '{}' if both 'uselist' and "
+ "'collection_class' arguments are absent from the "
+ "relationship(); please specify a "
+ "type annotation on the left hand side."
+ )
+ util.fail(api, msg.format(node.name), node)
+
+ if python_type_for_type is None:
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif left_hand_explicit_type is not None:
+ if type_is_a_collection:
+ assert isinstance(left_hand_explicit_type, Instance)
+ assert isinstance(python_type_for_type, Instance)
+ return _infer_collection_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+ else:
+ return _infer_type_from_left_and_inferred_right(
+ api,
+ node,
+ left_hand_explicit_type,
+ python_type_for_type,
+ )
+ else:
+ return python_type_for_type
+
+
+def _infer_type_from_decl_composite_property(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a CompositeProperty."""
+
+ assert isinstance(stmt.rvalue, CallExpr)
+ target_cls_arg = stmt.rvalue.args[0]
+ python_type_for_type = None
+
+ if isinstance(target_cls_arg, NameExpr) and isinstance(
+ target_cls_arg.node, TypeInfo
+ ):
+ related_object_type = target_cls_arg.node
+ python_type_for_type = Instance(related_object_type, [])
+ else:
+ python_type_for_type = None
+
+ if python_type_for_type is None:
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+ elif left_hand_explicit_type is not None:
+ return _infer_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+ else:
+ return python_type_for_type
+
+
+def _infer_type_from_decl_column_property(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a ColumnProperty.
+
+ This includes mappings against ``column_property()`` as well as the
+ ``deferred()`` function.
+
+ """
+ assert isinstance(stmt.rvalue, CallExpr)
+
+ if stmt.rvalue.args:
+ first_prop_arg = stmt.rvalue.args[0]
+
+ if isinstance(first_prop_arg, CallExpr):
+ type_id = names.type_id_for_callee(first_prop_arg.callee)
+
+ # look for column_property() / deferred() etc with Column as first
+ # argument
+ if type_id is names.COLUMN:
+ return _infer_type_from_decl_column(
+ api,
+ stmt,
+ node,
+ left_hand_explicit_type,
+ right_hand_expression=first_prop_arg,
+ )
+
+ if isinstance(stmt.rvalue, CallExpr):
+ type_id = names.type_id_for_callee(stmt.rvalue.callee)
+ # this is probably not strictly necessary, as we have to use the
+ # left hand type for query expressions in any case; any other no-arg
+ # column prop objects would go here as well
+ if type_id is names.QUERY_EXPRESSION:
+ return _infer_type_from_decl_column(
+ api,
+ stmt,
+ node,
+ left_hand_explicit_type,
+ )
+
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+
+
+def _infer_type_from_decl_column(
+ api: SemanticAnalyzerPluginInterface,
+ stmt: AssignmentStmt,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+ right_hand_expression: Optional[CallExpr] = None,
+) -> Optional[ProperType]:
+ """Infer the type of mapping from a Column.
+
+ E.g.::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ a = Column(Integer)
+
+ b = Column("b", String)
+
+ c: Mapped[int] = Column(Integer)
+
+ d: bool = Column(Boolean)
+
+ Will resolve in MyPy as::
+
+ @reg.mapped
+ class MyClass:
+ # ...
+
+ a : Mapped[int]
+
+ b : Mapped[str]
+
+ c: Mapped[int]
+
+ d: Mapped[bool]
+
+ """
+ assert isinstance(node, Var)
+
+ callee = None
+
+ if right_hand_expression is None:
+ if not isinstance(stmt.rvalue, CallExpr):
+ return None
+
+ right_hand_expression = stmt.rvalue
+
+ for column_arg in right_hand_expression.args[0:2]:
+ if isinstance(column_arg, CallExpr):
+ if isinstance(column_arg.callee, RefExpr):
+ # x = Column(String(50))
+ callee = column_arg.callee
+ type_args: Sequence[Expression] = column_arg.args
+ break
+ elif isinstance(column_arg, (NameExpr, MemberExpr)):
+ if isinstance(column_arg.node, TypeInfo):
+ # x = Column(String)
+ callee = column_arg
+ type_args = ()
+ break
+ else:
+ # x = Column(some_name, String), go to next argument
+ continue
+ elif isinstance(column_arg, (StrExpr,)):
+ # x = Column("name", String), go to next argument
+ continue
+ elif isinstance(column_arg, (LambdaExpr,)):
+ # x = Column("name", String, default=lambda: uuid.uuid4())
+ # go to next argument
+ continue
+ else:
+ assert False
+
+ if callee is None:
+ return None
+
+ if isinstance(callee.node, TypeInfo) and names.mro_has_id(
+ callee.node.mro, names.TYPEENGINE
+ ):
+ python_type_for_type = extract_python_type_from_typeengine(
+ api, callee.node, type_args
+ )
+
+ if left_hand_explicit_type is not None:
+
+ return _infer_type_from_left_and_inferred_right(
+ api, node, left_hand_explicit_type, python_type_for_type
+ )
+
+ else:
+ return UnionType([python_type_for_type, NoneType()])
+ else:
+ # it's not a TypeEngine; it's typically implicitly typed,
+ # like ForeignKey, so we can't infer from the right side.
+ return infer_type_from_left_hand_type_only(
+ api, node, left_hand_explicit_type
+ )
+
+
+def _infer_type_from_left_and_inferred_right(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: ProperType,
+ python_type_for_type: ProperType,
+ orig_left_hand_type: Optional[ProperType] = None,
+ orig_python_type_for_type: Optional[ProperType] = None,
+) -> Optional[ProperType]:
+ """Validate type when a left hand annotation is present and we also
+ could infer the right hand side::
+
+ attrname: SomeType = Column(SomeDBType)
+
+ """
+
+ if orig_left_hand_type is None:
+ orig_left_hand_type = left_hand_explicit_type
+ if orig_python_type_for_type is None:
+ orig_python_type_for_type = python_type_for_type
+
+ if not is_subtype(left_hand_explicit_type, python_type_for_type):
+ effective_type = api.named_type(
+ names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type]
+ )
+
+ msg = (
+ "Left hand assignment '{}: {}' not compatible "
+ "with ORM mapped expression of type {}"
+ )
+ util.fail(
+ api,
+ msg.format(
+ node.name,
+ format_type(orig_left_hand_type),
+ format_type(effective_type),
+ ),
+ node,
+ )
+
+ return orig_left_hand_type
+
+
+def _infer_collection_type_from_left_and_inferred_right(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: Instance,
+ python_type_for_type: Instance,
+) -> Optional[ProperType]:
+ orig_left_hand_type = left_hand_explicit_type
+ orig_python_type_for_type = python_type_for_type
+
+ if left_hand_explicit_type.args:
+ left_hand_arg = get_proper_type(left_hand_explicit_type.args[0])
+ python_type_arg = get_proper_type(python_type_for_type.args[0])
+ else:
+ left_hand_arg = left_hand_explicit_type
+ python_type_arg = python_type_for_type
+
+ assert isinstance(left_hand_arg, (Instance, UnionType))
+ assert isinstance(python_type_arg, (Instance, UnionType))
+
+ return _infer_type_from_left_and_inferred_right(
+ api,
+ node,
+ left_hand_arg,
+ python_type_arg,
+ orig_left_hand_type=orig_left_hand_type,
+ orig_python_type_for_type=orig_python_type_for_type,
+ )
+
+
+def infer_type_from_left_hand_type_only(
+ api: SemanticAnalyzerPluginInterface,
+ node: Var,
+ left_hand_explicit_type: Optional[ProperType],
+) -> Optional[ProperType]:
+ """Determine the type based on explicit annotation only.
+
+ If no annotation is present, an error is emitted noting that one is
+ needed there to determine the type.
+
+ """
+ if left_hand_explicit_type is None:
+ msg = (
+ "Can't infer type from ORM mapped expression "
+ "assigned to attribute '{}'; please specify a "
+ "Python type or "
+ "Mapped[<python type>] on the left hand side."
+ )
+ util.fail(api, msg.format(node.name), node)
+
+ return api.named_type(
+ names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)]
+ )
+
+ else:
+ # use type from the left hand side
+ return left_hand_explicit_type
+
+
+def extract_python_type_from_typeengine(
+ api: SemanticAnalyzerPluginInterface,
+ node: TypeInfo,
+ type_args: Sequence[Expression],
+) -> ProperType:
+ if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args:
+ first_arg = type_args[0]
+ if isinstance(first_arg, RefExpr) and isinstance(
+ first_arg.node, TypeInfo
+ ):
+ for base_ in first_arg.node.mro:
+ if base_.fullname == "enum.Enum":
+ return Instance(first_arg.node, [])
+ # TODO: support other pep-435 types here
+ else:
+ return api.named_type(names.NAMED_TYPE_BUILTINS_STR, [])
+
+ assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), (
+ "could not extract Python type from node: %s" % node
+ )
+
+ type_engine_sym = api.lookup_fully_qualified_or_none(
+ "sqlalchemy.sql.type_api.TypeEngine"
+ )
+
+ assert type_engine_sym is not None and isinstance(
+ type_engine_sym.node, TypeInfo
+ )
+ type_engine = map_instance_to_supertype(
+ Instance(node, []),
+ type_engine_sym.node,
+ )
+ return get_proper_type(type_engine.args[-1])
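+
+
+# e.g. (a sketch): Integer is annotated in the stubs as deriving from
+# TypeEngine[int], so extract_python_type_from_typeengine() resolves it to
+# builtins.int; Enum(SomeEnumClass) (a hypothetical PEP 435 class) resolves
+# to the enum class itself via the branch at the top.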
diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py
new file mode 100644
index 0000000..8ec15a6
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/names.py
@@ -0,0 +1,253 @@
+# ext/mypy/names.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+from typing import Union
+
+from mypy.nodes import ClassDef
+from mypy.nodes import Expression
+from mypy.nodes import FuncDef
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import SymbolNode
+from mypy.nodes import TypeAlias
+from mypy.nodes import TypeInfo
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import CallableType
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import UnboundType
+
+from ... import util
+
+COLUMN: int = util.symbol("COLUMN") # type: ignore
+RELATIONSHIP: int = util.symbol("RELATIONSHIP") # type: ignore
+REGISTRY: int = util.symbol("REGISTRY") # type: ignore
+COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") # type: ignore
+TYPEENGINE: int = util.symbol("TYPEENGINE")  # type: ignore
+MAPPED: int = util.symbol("MAPPED") # type: ignore
+DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE") # type: ignore
+DECLARATIVE_META: int = util.symbol("DECLARATIVE_META") # type: ignore
+MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR") # type: ignore
+SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY") # type: ignore
+COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY") # type: ignore
+DECLARED_ATTR: int = util.symbol("DECLARED_ATTR") # type: ignore
+MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY") # type: ignore
+AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE") # type: ignore
+AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE") # type: ignore
+DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN") # type: ignore
+QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") # type: ignore
+
+# names that must succeed with mypy.api.named_type
+NAMED_TYPE_BUILTINS_OBJECT = "builtins.object"
+NAMED_TYPE_BUILTINS_STR = "builtins.str"
+NAMED_TYPE_BUILTINS_LIST = "builtins.list"
+NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped"
+
+_lookup: Dict[str, Tuple[int, Set[str]]] = {
+ "Column": (
+ COLUMN,
+ {
+ "sqlalchemy.sql.schema.Column",
+ "sqlalchemy.sql.Column",
+ },
+ ),
+ "RelationshipProperty": (
+ RELATIONSHIP,
+ {
+ "sqlalchemy.orm.relationships.RelationshipProperty",
+ "sqlalchemy.orm.RelationshipProperty",
+ },
+ ),
+ "registry": (
+ REGISTRY,
+ {
+ "sqlalchemy.orm.decl_api.registry",
+ "sqlalchemy.orm.registry",
+ },
+ ),
+ "ColumnProperty": (
+ COLUMN_PROPERTY,
+ {
+ "sqlalchemy.orm.properties.ColumnProperty",
+ "sqlalchemy.orm.ColumnProperty",
+ },
+ ),
+ "SynonymProperty": (
+ SYNONYM_PROPERTY,
+ {
+ "sqlalchemy.orm.descriptor_props.SynonymProperty",
+ "sqlalchemy.orm.SynonymProperty",
+ },
+ ),
+ "CompositeProperty": (
+ COMPOSITE_PROPERTY,
+ {
+ "sqlalchemy.orm.descriptor_props.CompositeProperty",
+ "sqlalchemy.orm.CompositeProperty",
+ },
+ ),
+ "MapperProperty": (
+ MAPPER_PROPERTY,
+ {
+ "sqlalchemy.orm.interfaces.MapperProperty",
+ "sqlalchemy.orm.MapperProperty",
+ },
+ ),
+ "TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}),
+ "Mapped": (MAPPED, {"sqlalchemy.orm.attributes.Mapped"}),
+ "declarative_base": (
+ DECLARATIVE_BASE,
+ {
+ "sqlalchemy.ext.declarative.declarative_base",
+ "sqlalchemy.orm.declarative_base",
+ "sqlalchemy.orm.decl_api.declarative_base",
+ },
+ ),
+ "DeclarativeMeta": (
+ DECLARATIVE_META,
+ {
+ "sqlalchemy.ext.declarative.DeclarativeMeta",
+ "sqlalchemy.orm.DeclarativeMeta",
+ "sqlalchemy.orm.decl_api.DeclarativeMeta",
+ },
+ ),
+ "mapped": (
+ MAPPED_DECORATOR,
+ {
+ "sqlalchemy.orm.decl_api.registry.mapped",
+ "sqlalchemy.orm.registry.mapped",
+ },
+ ),
+ "as_declarative": (
+ AS_DECLARATIVE,
+ {
+ "sqlalchemy.ext.declarative.as_declarative",
+ "sqlalchemy.orm.decl_api.as_declarative",
+ "sqlalchemy.orm.as_declarative",
+ },
+ ),
+ "as_declarative_base": (
+ AS_DECLARATIVE_BASE,
+ {
+ "sqlalchemy.orm.decl_api.registry.as_declarative_base",
+ "sqlalchemy.orm.registry.as_declarative_base",
+ },
+ ),
+ "declared_attr": (
+ DECLARED_ATTR,
+ {
+ "sqlalchemy.orm.decl_api.declared_attr",
+ "sqlalchemy.orm.declared_attr",
+ },
+ ),
+ "declarative_mixin": (
+ DECLARATIVE_MIXIN,
+ {
+ "sqlalchemy.orm.decl_api.declarative_mixin",
+ "sqlalchemy.orm.declarative_mixin",
+ },
+ ),
+ "query_expression": (
+ QUERY_EXPRESSION,
+ {"sqlalchemy.orm.query_expression"},
+ ),
+}
+
+
+def has_base_type_id(info: TypeInfo, type_id: int) -> bool:
+ for mr in info.mro:
+ check_type_id, fullnames = _lookup.get(mr.name, (None, None))
+ if check_type_id == type_id:
+ break
+ else:
+ return False
+
+ if fullnames is None:
+ return False
+
+ return mr.fullname in fullnames
+
+
+def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool:
+ for mr in mro:
+ check_type_id, fullnames = _lookup.get(mr.name, (None, None))
+ if check_type_id == type_id:
+ break
+ else:
+ return False
+
+ if fullnames is None:
+ return False
+
+ return mr.fullname in fullnames
+
+
+def type_id_for_unbound_type(
+ type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface
+) -> Optional[int]:
+ sym = api.lookup_qualified(type_.name, type_)
+ if sym is not None:
+ if isinstance(sym.node, TypeAlias):
+ target_type = get_proper_type(sym.node.target)
+ if isinstance(target_type, Instance):
+ return type_id_for_named_node(target_type.type)
+ elif isinstance(sym.node, TypeInfo):
+ return type_id_for_named_node(sym.node)
+
+ return None
+
+
+def type_id_for_callee(callee: Expression) -> Optional[int]:
+ if isinstance(callee, (MemberExpr, NameExpr)):
+ if isinstance(callee.node, FuncDef):
+ if callee.node.type and isinstance(callee.node.type, CallableType):
+ ret_type = get_proper_type(callee.node.type.ret_type)
+
+ if isinstance(ret_type, Instance):
+ return type_id_for_fullname(ret_type.type.fullname)
+
+ return None
+ elif isinstance(callee.node, TypeAlias):
+ target_type = get_proper_type(callee.node.target)
+ if isinstance(target_type, Instance):
+ return type_id_for_fullname(target_type.type.fullname)
+ elif isinstance(callee.node, TypeInfo):
+ return type_id_for_named_node(callee)
+ return None
+
+
+def type_id_for_named_node(
+ node: Union[NameExpr, MemberExpr, SymbolNode]
+) -> Optional[int]:
+ type_id, fullnames = _lookup.get(node.name, (None, None))
+
+ if type_id is None or fullnames is None:
+ return None
+ elif node.fullname in fullnames:
+ return type_id
+ else:
+ return None
+
+
+def type_id_for_fullname(fullname: str) -> Optional[int]:
+ tokens = fullname.split(".")
+ immediate = tokens[-1]
+
+ type_id, fullnames = _lookup.get(immediate, (None, None))
+
+ if type_id is None or fullnames is None:
+ return None
+ elif fullname in fullnames:
+ return type_id
+ else:
+ return None
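+
+
+# e.g. (a sketch): type_id_for_fullname("sqlalchemy.sql.schema.Column")
+# returns COLUMN, while a hypothetical "mymodule.Column" returns None,
+# since only the fullnames registered in _lookup are recognized.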
diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py
new file mode 100644
index 0000000..8687012
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/plugin.py
@@ -0,0 +1,284 @@
+# ext/mypy/plugin.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""
+Mypy plugin for SQLAlchemy ORM.
+
+"""
+from typing import Callable
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Type as TypingType
+from typing import Union
+
+from mypy import nodes
+from mypy.mro import calculate_mro
+from mypy.mro import MroError
+from mypy.nodes import Block
+from mypy.nodes import ClassDef
+from mypy.nodes import GDEF
+from mypy.nodes import MypyFile
+from mypy.nodes import NameExpr
+from mypy.nodes import SymbolTable
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeInfo
+from mypy.plugin import AttributeContext
+from mypy.plugin import ClassDefContext
+from mypy.plugin import DynamicClassDefContext
+from mypy.plugin import Plugin
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.types import get_proper_type
+from mypy.types import Instance
+from mypy.types import Type
+
+from . import decl_class
+from . import names
+from . import util
+
+
+class SQLAlchemyPlugin(Plugin):
+ def get_dynamic_class_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[DynamicClassDefContext], None]]:
+ if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE:
+ return _dynamic_class_hook
+ return None
+
+ def get_customize_class_mro_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+ return _fill_in_decorators
+
+ def get_class_decorator_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+
+ sym = self.lookup_fully_qualified(fullname)
+
+ if sym is not None and sym.node is not None:
+ type_id = names.type_id_for_named_node(sym.node)
+ if type_id is names.MAPPED_DECORATOR:
+ return _cls_decorator_hook
+ elif type_id in (
+ names.AS_DECLARATIVE,
+ names.AS_DECLARATIVE_BASE,
+ ):
+ return _base_cls_decorator_hook
+ elif type_id is names.DECLARATIVE_MIXIN:
+ return _declarative_mixin_hook
+
+ return None
+
+ def get_metaclass_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+ if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META:
+ # Set any classes that explicitly have metaclass=DeclarativeMeta
+ # as declarative so the check in `get_base_class_hook()` works
+ return _metaclass_cls_hook
+
+ return None
+
+ def get_base_class_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[ClassDefContext], None]]:
+ sym = self.lookup_fully_qualified(fullname)
+
+ if (
+ sym
+ and isinstance(sym.node, TypeInfo)
+ and util.has_declarative_base(sym.node)
+ ):
+ return _base_cls_hook
+
+ return None
+
+ def get_attribute_hook(
+ self, fullname: str
+ ) -> Optional[Callable[[AttributeContext], Type]]:
+ if fullname.startswith(
+ "sqlalchemy.orm.attributes.QueryableAttribute."
+ ):
+ return _queryable_getattr_hook
+
+ return None
+
+ def get_additional_deps(
+ self, file: MypyFile
+ ) -> List[Tuple[int, str, int]]:
+ return [
+ (10, "sqlalchemy.orm.attributes", -1),
+ (10, "sqlalchemy.orm.decl_api", -1),
+ ]
+
+
+def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
+ return SQLAlchemyPlugin
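+
+# The plugin is enabled through standard mypy configuration (a usage note,
+# not part of this module), e.g. in setup.cfg or mypy.ini::
+#
+#     [mypy]
+#     plugins = sqlalchemy.ext.mypy.plugin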
+
+
+def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
+ """Generate a declarative Base class when the declarative_base() function
+ is encountered."""
+
+ _add_globals(ctx)
+
+ cls = ClassDef(ctx.name, Block([]))
+ cls.fullname = ctx.api.qualified_name(ctx.name)
+
+ info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
+ cls.info = info
+ _set_declarative_metaclass(ctx.api, cls)
+
+ cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
+ if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
+ util.set_is_base(cls_arg.node)
+ decl_class.scan_declarative_assignments_and_apply_types(
+ cls_arg.node.defn, ctx.api, is_mixin_scan=True
+ )
+ info.bases = [Instance(cls_arg.node, [])]
+ else:
+ obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
+
+ info.bases = [obj]
+
+ try:
+ calculate_mro(info)
+ except MroError:
+ util.fail(
+ ctx.api, "Not able to calculate MRO for declarative base", ctx.call
+ )
+ obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
+ info.bases = [obj]
+ info.fallback_to_any = True
+
+ ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
+ util.set_is_base(info)
+
+
+def _fill_in_decorators(ctx: ClassDefContext) -> None:
+ for decorator in ctx.cls.decorators:
+ # set the ".fullname" attribute of a class decorator
+ # that is a MemberExpr. This causes the logic in
+ # semanal.py->apply_class_plugin_hooks to invoke the
+ # get_class_decorator_hook for our "registry.mapped()"
+ # and "registry.as_declarative_base()" methods.
+ # This seems like a bug in mypy, in that these decorators are
+ # otherwise skipped.
+
+ if (
+ isinstance(decorator, nodes.CallExpr)
+ and isinstance(decorator.callee, nodes.MemberExpr)
+ and decorator.callee.name == "as_declarative_base"
+ ):
+ target = decorator.callee
+ elif (
+ isinstance(decorator, nodes.MemberExpr)
+ and decorator.name == "mapped"
+ ):
+ target = decorator
+ else:
+ continue
+
+ assert isinstance(target.expr, NameExpr)
+ sym = ctx.api.lookup_qualified(
+ target.expr.name, target, suppress_errors=True
+ )
+ if sym and sym.node:
+ sym_type = get_proper_type(sym.type)
+ if isinstance(sym_type, Instance):
+ target.fullname = f"{sym_type.type.fullname}.{target.name}"
+ else:
+ # if the registry is in the same file as where the
+ # decorator is used, it might not have semantic
+ # symbols applied yet, in which case we can't get a fully
+ # qualified name or an inferred type; so we flag an error
+ # here directing the user to annotate the assignment.
+ # The "registry" is declared just once (or a few times), so
+ # forgoing type inference for its assignment in this one
+ # case is a small burden.
+ util.fail(
+ ctx.api,
+ "Class decorator called %s(), but we can't "
+ "tell if it's from an ORM registry. Please "
+ "annotate the registry assignment, e.g. "
+ "my_registry: registry = registry()" % target.name,
+ sym.node,
+ )
+
+
+def _cls_decorator_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+ assert isinstance(ctx.reason, nodes.MemberExpr)
+ expr = ctx.reason.expr
+
+ assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var)
+
+ node_type = get_proper_type(expr.node.type)
+
+ assert (
+ isinstance(node_type, Instance)
+ and names.type_id_for_named_node(node_type.type) is names.REGISTRY
+ )
+
+ decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _base_cls_decorator_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+
+ cls = ctx.cls
+
+ _set_declarative_metaclass(ctx.api, cls)
+
+ util.set_is_base(ctx.cls.info)
+ decl_class.scan_declarative_assignments_and_apply_types(
+ cls, ctx.api, is_mixin_scan=True
+ )
+
+
+def _declarative_mixin_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+ util.set_is_base(ctx.cls.info)
+ decl_class.scan_declarative_assignments_and_apply_types(
+ ctx.cls, ctx.api, is_mixin_scan=True
+ )
+
+
+def _metaclass_cls_hook(ctx: ClassDefContext) -> None:
+ util.set_is_base(ctx.cls.info)
+
+
+def _base_cls_hook(ctx: ClassDefContext) -> None:
+ _add_globals(ctx)
+ decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api)
+
+
+def _queryable_getattr_hook(ctx: AttributeContext) -> Type:
+ # TODO: it's unclear how to report that an attribute of a given
+ # name does not exist; no Type appears to express that
+ return ctx.default_attr_type
+
+
+def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None:
+ """Add the __sa_Mapped symbol to the global scope for all class defs.
+
+ """
+
+ util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped")
+
+
+def _set_declarative_metaclass(
+ api: SemanticAnalyzerPluginInterface, target_cls: ClassDef
+) -> None:
+ info = target_cls.info
+ sym = api.lookup_fully_qualified_or_none(
+ "sqlalchemy.orm.decl_api.DeclarativeMeta"
+ )
+ assert sym is not None and isinstance(sym.node, TypeInfo)
+ info.declared_metaclass = info.metaclass_type = Instance(sym.node, [])
diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py
new file mode 100644
index 0000000..16b365e
--- /dev/null
+++ b/lib/sqlalchemy/ext/mypy/util.py
@@ -0,0 +1,305 @@
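+# ext/mypy/util.py
+# Copyright (C) 2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+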
+import re
+from typing import Any
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import Type as TypingType
+from typing import TypeVar
+from typing import Union
+
+from mypy.nodes import ARG_POS
+from mypy.nodes import CallExpr
+from mypy.nodes import ClassDef
+from mypy.nodes import CLASSDEF_NO_INFO
+from mypy.nodes import Context
+from mypy.nodes import Expression
+from mypy.nodes import IfStmt
+from mypy.nodes import JsonDict
+from mypy.nodes import MemberExpr
+from mypy.nodes import NameExpr
+from mypy.nodes import Statement
+from mypy.nodes import SymbolTableNode
+from mypy.nodes import TypeInfo
+from mypy.plugin import ClassDefContext
+from mypy.plugin import DynamicClassDefContext
+from mypy.plugin import SemanticAnalyzerPluginInterface
+from mypy.plugins.common import deserialize_and_fixup_type
+from mypy.typeops import map_type_from_supertype
+from mypy.types import Instance
+from mypy.types import NoneType
+from mypy.types import Type
+from mypy.types import TypeVarType
+from mypy.types import UnboundType
+from mypy.types import UnionType
+
+
+_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr])
+
+
+class SQLAlchemyAttribute:
+ def __init__(
+ self,
+ name: str,
+ line: int,
+ column: int,
+ typ: Optional[Type],
+ info: TypeInfo,
+ ) -> None:
+ self.name = name
+ self.line = line
+ self.column = column
+ self.type = typ
+ self.info = info
+
+ def serialize(self) -> JsonDict:
+ assert self.type
+ return {
+ "name": self.name,
+ "line": self.line,
+ "column": self.column,
+ "type": self.type.serialize(),
+ }
+
+ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
+ """Expands type vars in the context of a subtype when an attribute is
+ inherited from a generic super type.
+ """
+ if not isinstance(self.type, TypeVarType):
+ return
+
+ self.type = map_type_from_supertype(self.type, sub_type, self.info)
+
+ @classmethod
+ def deserialize(
+ cls,
+ info: TypeInfo,
+ data: JsonDict,
+ api: SemanticAnalyzerPluginInterface,
+ ) -> "SQLAlchemyAttribute":
+ data = data.copy()
+ typ = deserialize_and_fixup_type(data.pop("type"), api)
+ return cls(typ=typ, info=info, **data)
+
+
+def name_is_dunder(name: str) -> bool:
+ return bool(re.match(r"^__.+?__$", name))
+
+
+def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None:
+ info.metadata.setdefault("sqlalchemy", {})[key] = data
+
+
+def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+ return info.metadata.get("sqlalchemy", {}).get(key, None)
+
+
+def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]:
+ if info.mro:
+ for base in info.mro:
+ metadata = _get_info_metadata(base, key)
+ if metadata is not None:
+ return metadata
+ return None
+
+
+def establish_as_sqlalchemy(info: TypeInfo) -> None:
+ info.metadata.setdefault("sqlalchemy", {})
+
+
+def set_is_base(info: TypeInfo) -> None:
+ _set_info_metadata(info, "is_base", True)
+
+
+def get_is_base(info: TypeInfo) -> bool:
+ is_base = _get_info_metadata(info, "is_base")
+ return is_base is True
+
+
+def has_declarative_base(info: TypeInfo) -> bool:
+ is_base = _get_info_mro_metadata(info, "is_base")
+ return is_base is True
+
+
+def set_has_table(info: TypeInfo) -> None:
+ _set_info_metadata(info, "has_table", True)
+
+
+def get_has_table(info: TypeInfo) -> bool:
+ is_base = _get_info_metadata(info, "has_table")
+ return is_base is True
+
+
+def get_mapped_attributes(
+ info: TypeInfo, api: SemanticAnalyzerPluginInterface
+) -> Optional[List[SQLAlchemyAttribute]]:
+ mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata(
+ info, "mapped_attributes"
+ )
+ if mapped_attributes is None:
+ return None
+
+ attributes: List[SQLAlchemyAttribute] = []
+
+ for data in mapped_attributes:
+ attr = SQLAlchemyAttribute.deserialize(info, data, api)
+ attr.expand_typevar_from_subtype(info)
+ attributes.append(attr)
+
+ return attributes
+
+
+def set_mapped_attributes(
+ info: TypeInfo, attributes: List[SQLAlchemyAttribute]
+) -> None:
+ _set_info_metadata(
+ info,
+ "mapped_attributes",
+ [attribute.serialize() for attribute in attributes],
+ )
+
+
+def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None:
+ msg = "[SQLAlchemy Mypy plugin] %s" % msg
+ return api.fail(msg, ctx)
+
+
+def add_global(
+ ctx: Union[ClassDefContext, DynamicClassDefContext],
+ module: str,
+ symbol_name: str,
+ asname: str,
+) -> None:
+ module_globals = ctx.api.modules[ctx.api.cur_mod_id].names
+
+ if asname not in module_globals:
+ lookup_sym: SymbolTableNode = ctx.api.modules[module].names[
+ symbol_name
+ ]
+
+ module_globals[asname] = lookup_sym
+
+
+@overload
+def get_callexpr_kwarg(
+ callexpr: CallExpr, name: str, *, expr_types: None = ...
+) -> Optional[Union[CallExpr, NameExpr]]:
+ ...
+
+
+@overload
+def get_callexpr_kwarg(
+ callexpr: CallExpr,
+ name: str,
+ *,
+ expr_types: Tuple[TypingType[_TArgType], ...]
+) -> Optional[_TArgType]:
+ ...
+
+
+def get_callexpr_kwarg(
+ callexpr: CallExpr,
+ name: str,
+ *,
+ expr_types: Optional[Tuple[TypingType[Any], ...]] = None
+) -> Optional[Any]:
+ try:
+ arg_idx = callexpr.arg_names.index(name)
+ except ValueError:
+ return None
+
+ kwarg = callexpr.args[arg_idx]
+ if isinstance(
+ kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
+ ):
+ return kwarg
+
+ return None
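+
+# e.g. (a sketch, names hypothetical): given a CallExpr for
+# relationship("B", uselist=False), get_callexpr_kwarg(call, "uselist")
+# returns the NameExpr for False; a keyword whose value is not one of the
+# accepted expression types yields None.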
+
+
+def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
+ for stmt in stmts:
+ if (
+ isinstance(stmt, IfStmt)
+ and isinstance(stmt.expr[0], NameExpr)
+ and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
+ ):
+ for substmt in stmt.body[0].body:
+ yield substmt
+ else:
+ yield stmt
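+
+# e.g. statements written under "if TYPE_CHECKING:" in a class body are
+# yielded inline by flatten_typechecking(), so conditionally declared
+# attributes are scanned the same way as unconditional ones.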
+
+
+def unbound_to_instance(
+ api: SemanticAnalyzerPluginInterface, typ: Type
+) -> Type:
+ """Take the UnboundType that we seem to get as the ret_type from a FuncDef
+ and convert it into an Instance/TypeInfo kind of structure that seems
+ to work as the left-hand type of an AssignmentStatement.
+
+ """
+
+ if not isinstance(typ, UnboundType):
+ return typ
+
+ # TODO: figure out a more robust way to check this. The node is some
+ # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm,
+ # but I can't figure out how to get them to match up
+ if typ.name == "Optional":
+ # convert from "Optional?" to the more familiar
+ # UnionType[..., NoneType()]
+ return unbound_to_instance(
+ api,
+ UnionType(
+ [unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+ + [NoneType()]
+ ),
+ )
+
+ node = api.lookup_qualified(typ.name, typ)
+
+ if (
+ node is not None
+ and isinstance(node, SymbolTableNode)
+ and isinstance(node.node, TypeInfo)
+ ):
+ bound_type = node.node
+
+ return Instance(
+ bound_type,
+ [
+ unbound_to_instance(api, arg)
+ if isinstance(arg, UnboundType)
+ else arg
+ for arg in typ.args
+ ],
+ )
+ else:
+ return typ
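+
+# e.g. (a sketch): an UnboundType for "Optional[int]" becomes
+# UnionType([Instance for builtins.int, NoneType()]), applying the lookup
+# recursively to each type argument.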
+
+
+def info_for_cls(
+ cls: ClassDef, api: SemanticAnalyzerPluginInterface
+) -> Optional[TypeInfo]:
+ if cls.info is CLASSDEF_NO_INFO:
+ sym = api.lookup_qualified(cls.name, cls)
+ if sym is None:
+ return None
+ assert sym and isinstance(sym.node, TypeInfo)
+ return sym.node
+
+ return cls.info
+
+
+def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
+ column_descriptor = NameExpr("__sa_Mapped")
+ column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
+ member_expr = MemberExpr(column_descriptor, "_empty_constructor")
+ return CallExpr(
+ member_expr,
+ [expr],
+ [ARG_POS],
+ ["arg1"],
+ )
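+
+
+# a sketch of the rewrite expr_to_mapped_constructor() enables: an
+# assignment such as
+#
+#     id = Column(Integer)
+#
+# can be rewritten by the plugin as
+#
+#     id = __sa_Mapped._empty_constructor(Column(Integer))
+#
+# so that the right-hand side is typed in terms of Mapped.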
diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py
new file mode 100644
index 0000000..5a327d1
--- /dev/null
+++ b/lib/sqlalchemy/ext/orderinglist.py
@@ -0,0 +1,388 @@
+# ext/orderinglist.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""A custom list that manages index/position information for contained
+elements.
+
+:author: Jason Kirtland
+
+``orderinglist`` is a helper for mutable ordered relationships. It will
+intercept list operations performed on a :func:`_orm.relationship`-managed
+collection and
+automatically synchronize changes in list position onto a target scalar
+attribute.
+
+Example: A ``slide`` table, where each row refers to zero or more entries
+in a related ``bullet`` table. The bullets within a slide are
+displayed in order based on the value of the ``position`` column in the
+``bullet`` table. As entries are reordered in memory, the value of the
+``position`` attribute should be updated to reflect the new sort order::
+
+
+ Base = declarative_base()
+
+ class Slide(Base):
+ __tablename__ = 'slide'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ bullets = relationship("Bullet", order_by="Bullet.position")
+
+ class Bullet(Base):
+ __tablename__ = 'bullet'
+ id = Column(Integer, primary_key=True)
+ slide_id = Column(Integer, ForeignKey('slide.id'))
+ position = Column(Integer)
+ text = Column(String)
+
+The standard relationship mapping will produce a list-like attribute on each
+``Slide`` containing all related ``Bullet`` objects, but changes in
+ordering are not handled automatically.
+When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position``
+attribute will remain unset until manually assigned. When the ``Bullet``
+is inserted into the middle of the list, the following ``Bullet`` objects
+will also need to be renumbered.
+
+The :class:`.OrderingList` object automates this task, managing the
+``position`` attribute on all ``Bullet`` objects in the collection. It is
+constructed using the :func:`.ordering_list` factory::
+
+ from sqlalchemy.ext.orderinglist import ordering_list
+
+ Base = declarative_base()
+
+ class Slide(Base):
+ __tablename__ = 'slide'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ bullets = relationship("Bullet", order_by="Bullet.position",
+ collection_class=ordering_list('position'))
+
+ class Bullet(Base):
+ __tablename__ = 'bullet'
+ id = Column(Integer, primary_key=True)
+ slide_id = Column(Integer, ForeignKey('slide.id'))
+ position = Column(Integer)
+ text = Column(String)
+
+With the above mapping the ``Bullet.position`` attribute is managed::
+
+ >>> s = Slide()
+ >>> s.bullets.append(Bullet())
+ >>> s.bullets.append(Bullet())
+ >>> s.bullets[1].position
+ 1
+ >>> s.bullets.insert(1, Bullet())
+ >>> s.bullets[2].position
+ 2
+
+The :class:`.OrderingList` construct only works with **changes** to a
+collection, and not the initial load from the database, and requires that the
+list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the
+:func:`_orm.relationship` against the target ordering attribute, so that the
+ordering is correct when first loaded.
+
+.. warning::
+
+ :class:`.OrderingList` only provides limited functionality when a primary
+ key column or unique column is the target of the sort. Operations
+ that are unsupported or are problematic include:
+
+ * two entries must trade values. This is not supported directly in the
+ case of a primary key or unique constraint because it means at least
+ one row would need to be temporarily removed first, or changed to
+ a third, neutral value while the switch occurs.
+
+ * an entry must be deleted in order to make room for a new entry.
+ SQLAlchemy's unit of work performs all INSERTs before DELETEs within a
+ single flush. In the case of a primary key, it will trade
+ an INSERT/DELETE of the same primary key for an UPDATE statement in order
+ to lessen the impact of this limitation, however this does not take place
+ for a UNIQUE column.
+ A future feature will allow the "DELETE before INSERT" behavior to be
+ possible, alleviating this limitation, though this feature will require
+ explicit configuration at the mapper level for sets of columns that
+ are to be handled in this way.
+
+:func:`.ordering_list` takes the name of the related object's ordering
+attribute as an argument. By default, the zero-based integer index of the
+object's position in the :func:`.ordering_list` is synchronized with the
+ordering attribute: index 0 will get position 0, index 1 position 1, etc. To
+start numbering at 1 or some other integer, provide ``count_from=1``.
+
+
+"""
+from ..orm.collections import collection
+from ..orm.collections import collection_adapter
+
+
+__all__ = ["ordering_list"]
+
+
+def ordering_list(attr, count_from=None, **kw):
+ """Prepares an :class:`OrderingList` factory for use in mapper definitions.
+
+ Returns an object suitable for use as an argument to a Mapper
+ relationship's ``collection_class`` option. e.g.::
+
+ from sqlalchemy.ext.orderinglist import ordering_list
+
+ class Slide(Base):
+ __tablename__ = 'slide'
+
+ id = Column(Integer, primary_key=True)
+ name = Column(String)
+
+ bullets = relationship("Bullet", order_by="Bullet.position",
+ collection_class=ordering_list('position'))
+
+ :param attr:
+ Name of the mapped attribute to use for storage and retrieval of
+ ordering information
+
+ :param count_from:
+ Set up an integer-based ordering, starting at ``count_from``. For
+ example, ``ordering_list('pos', count_from=1)`` would create a 1-based
+ list in SQL, storing the value in the 'pos' column. Ignored if
+ ``ordering_func`` is supplied.
+
+ Additional arguments are passed to the :class:`.OrderingList` constructor.
+
+ """
+
+ kw = _unsugar_count_from(count_from=count_from, **kw)
+ return lambda: OrderingList(attr, **kw)
+
+
+# Ordering utility functions
+
+
+def count_from_0(index, collection):
+ """Numbering function: consecutive integers starting at 0."""
+
+ return index
+
+
+def count_from_1(index, collection):
+ """Numbering function: consecutive integers starting at 1."""
+
+ return index + 1
+
+
+def count_from_n_factory(start):
+ """Numbering function: consecutive integers from an arbitrary start."""
+
+ def f(index, collection):
+ return index + start
+
+ try:
+ f.__name__ = "count_from_%i" % start
+ except TypeError:
+ pass
+ return f
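+
+# Any callable accepting (index, collection) may serve as an ordering_func;
+# e.g. a hypothetical stepped numbering::
+#
+#     def step_of_10(index, collection):
+#         return index * 10
+#
+#     ordering_list("position", ordering_func=step_of_10)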
+
+
+def _unsugar_count_from(**kw):
+ """Builds counting functions from keyword arguments.
+
+ A keyword argument filter: prepares a simple ``ordering_func`` from a
+ ``count_from`` argument, otherwise passes ``ordering_func`` through
+ unchanged.
+ """
+
+ count_from = kw.pop("count_from", None)
+ if kw.get("ordering_func", None) is None and count_from is not None:
+ if count_from == 0:
+ kw["ordering_func"] = count_from_0
+ elif count_from == 1:
+ kw["ordering_func"] = count_from_1
+ else:
+ kw["ordering_func"] = count_from_n_factory(count_from)
+ return kw
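+
+# e.g. _unsugar_count_from(count_from=1) returns
+# {"ordering_func": count_from_1}; an explicitly supplied ordering_func
+# takes precedence, in which case count_from is ignored.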
+
+
+class OrderingList(list):
+ """A custom list that manages position information for its children.
+
+ The :class:`.OrderingList` object is normally set up using the
+ :func:`.ordering_list` factory function, used in conjunction with
+ the :func:`_orm.relationship` function.
+
+ """
+
+ def __init__(
+ self, ordering_attr=None, ordering_func=None, reorder_on_append=False
+ ):
+ """A custom list that manages position information for its children.
+
+ ``OrderingList`` is a ``collection_class`` list implementation that
+ syncs position in a Python list with a position attribute on the
+ mapped objects.
+
+ This implementation relies on the list starting in the proper order,
+ so be **sure** to put an ``order_by`` on your relationship.
+
+ :param ordering_attr:
+ Name of the attribute that stores the object's order in the
+ relationship.
+
+ :param ordering_func: Optional. A function that maps the position in
+ the Python list to a value to store in the
+ ``ordering_attr``. Values returned are usually (but need not be!)
+ integers.
+
+ An ``ordering_func`` is called with two positional parameters: the
+ index of the element in the list, and the list itself.
+
+ If omitted, Python list indexes are used for the attribute values.
+ Two basic pre-built numbering functions are provided in this module:
+ ``count_from_0`` and ``count_from_1``. For more exotic examples
+ like stepped numbering, alphabetical and Fibonacci numbering, see
+ the unit tests.
+
+ :param reorder_on_append:
+ Default False. When appending an object with an existing (non-None)
+ ordering value, that value will be left untouched unless
+ ``reorder_on_append`` is true. This is an optimization to avoid a
+ variety of dangerous unexpected database writes.
+
+ SQLAlchemy will add instances to the list via append() when your
+ object loads. If for some reason the result set from the database
+ skips a step in the ordering (say, row '1' is missing but you get
+ '2', '3', and '4'), reorder_on_append=True would immediately
+ renumber the items to '1', '2', '3'. If you have multiple sessions
+ making changes, any of whom happen to load this collection even in
+ passing, all of the sessions would try to "clean up" the numbering
+ in their commits, possibly causing all but one to fail with a
+ concurrent modification error.
+
+ It's recommended to leave this at its default of False, and to
+ call ``reorder()`` explicitly when appending previously ordered
+ instances or when doing housekeeping after manual SQL operations.
+
+ """
+ self.ordering_attr = ordering_attr
+ if ordering_func is None:
+ ordering_func = count_from_0
+ self.ordering_func = ordering_func
+ self.reorder_on_append = reorder_on_append
+
+ # More complex serialization schemes (multi column, e.g.) are possible by
+ # subclassing and reimplementing these two methods.
+ def _get_order_value(self, entity):
+ return getattr(entity, self.ordering_attr)
+
+ def _set_order_value(self, entity, value):
+ setattr(entity, self.ordering_attr, value)
+
+ def reorder(self):
+ """Synchronize ordering for the entire collection.
+
+ Sweeps through the list and ensures that each object has accurate
+ ordering information set.
+
+ """
+ for index, entity in enumerate(self):
+ self._order_entity(index, entity, True)
+
+ # As of 0.5, _reorder is no longer semi-private
+ _reorder = reorder
+
+ def _order_entity(self, index, entity, reorder=True):
+ have = self._get_order_value(entity)
+
+ # Don't disturb existing ordering if reorder is False
+ if have is not None and not reorder:
+ return
+
+ should_be = self.ordering_func(index, self)
+ if have != should_be:
+ self._set_order_value(entity, should_be)
+
+ def append(self, entity):
+ super(OrderingList, self).append(entity)
+ self._order_entity(len(self) - 1, entity, self.reorder_on_append)
+
+ def _raw_append(self, entity):
+ """Append without any ordering behavior."""
+
+ super(OrderingList, self).append(entity)
+
+ _raw_append = collection.adds(1)(_raw_append)
+
+ def insert(self, index, entity):
+ super(OrderingList, self).insert(index, entity)
+ self._reorder()
+
+ def remove(self, entity):
+ super(OrderingList, self).remove(entity)
+
+ adapter = collection_adapter(self)
+ if adapter and adapter._referenced_by_owner:
+ self._reorder()
+
+ def pop(self, index=-1):
+ entity = super(OrderingList, self).pop(index)
+ self._reorder()
+ return entity
+
+ def __setitem__(self, index, entity):
+ if isinstance(index, slice):
+ step = index.step or 1
+ start = index.start or 0
+ if start < 0:
+ start += len(self)
+ stop = index.stop or len(self)
+ if stop < 0:
+ stop += len(self)
+
+ # pair each target index with its replacement value; indexing the
+ # value list by the slice index would fail for slices not starting
+ # at zero
+ for i, e in zip(range(start, stop, step), entity):
+ self.__setitem__(i, e)
+ else:
+ self._order_entity(index, entity, True)
+ super(OrderingList, self).__setitem__(index, entity)
+
+ def __delitem__(self, index):
+ super(OrderingList, self).__delitem__(index)
+ self._reorder()
+
+ def __setslice__(self, start, end, values):
+ super(OrderingList, self).__setslice__(start, end, values)
+ self._reorder()
+
+ def __delslice__(self, start, end):
+ super(OrderingList, self).__delslice__(start, end)
+ self._reorder()
+
+ def __reduce__(self):
+ return _reconstitute, (self.__class__, self.__dict__, list(self))
+
+ for func_name, func in list(locals().items()):
+ if (
+ callable(func)
+ and func.__name__ == func_name
+ and not func.__doc__
+ and hasattr(list, func_name)
+ ):
+ func.__doc__ = getattr(list, func_name).__doc__
+ del func_name, func
+
+
+def _reconstitute(cls, dict_, items):
+ """Reconstitute an :class:`.OrderingList`.
+
+ This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for
+ unpickling :class:`.OrderingList` objects.
+
+ """
+ obj = cls.__new__(cls)
+ obj.__dict__.update(dict_)
+ list.extend(obj, items)
+ return obj
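+
+
+# A minimal pickling round trip (a sketch; assumes any contained entities
+# are themselves picklable)::
+#
+#     import pickle
+#
+#     ol = OrderingList("position")
+#     restored = pickle.loads(pickle.dumps(ol))
+#     assert restored.ordering_attr == "position"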
diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py
new file mode 100644
index 0000000..094b71b
--- /dev/null
+++ b/lib/sqlalchemy/ext/serializer.py
@@ -0,0 +1,177 @@
+# ext/serializer.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Serializer/Deserializer objects for usage with SQLAlchemy query structures,
+allowing "contextual" deserialization.
+
+Any SQLAlchemy query structure, either based on sqlalchemy.sql.* or
+sqlalchemy.orm.*, can be used. The mappers, Tables, Columns, Session,
+etc. which are referenced by the structure are not persisted in serialized
+form, but are instead re-associated with the query structure
+when it is deserialized.
+
+Usage is nearly the same as that of the standard Python pickle module::
+
+ from sqlalchemy.ext.serializer import loads, dumps
+ metadata = MetaData(bind=some_engine)
+ Session = scoped_session(sessionmaker())
+
+ # ... define mappers
+
+ query = Session.query(MyClass).\
+ filter(MyClass.somedata == 'foo').order_by(MyClass.sortkey)
+
+ # pickle the query
+ serialized = dumps(query)
+
+ # unpickle. Pass in metadata + scoped_session
+ query2 = loads(serialized, metadata, Session)
+
+ print(query2.all())
+
+The same restrictions that apply to raw pickle apply here as well; mapped
+classes must themselves be pickleable, meaning they are importable from a
+module-level namespace.
+
+The serializer module is only appropriate for query structures. It is not
+needed for:
+
+* instances of user-defined classes. These contain no references to engines,
+ sessions or expression constructs in the typical case and can be serialized
+ directly.
+
+* Table metadata that is to be loaded entirely from the serialized structure
+ (i.e. is not already declared in the application). Regular
+ pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object,
+ typically one which was reflected from an existing database at some previous
+ point in time. The serializer module is specifically for the opposite case,
+ where the Table metadata is already present in memory.
+
+"""
+
+import re
+
+from .. import Column
+from .. import Table
+from ..engine import Engine
+from ..orm import class_mapper
+from ..orm.interfaces import MapperProperty
+from ..orm.mapper import Mapper
+from ..orm.session import Session
+from ..util import b64decode
+from ..util import b64encode
+from ..util import byte_buffer
+from ..util import pickle
+from ..util import text_type
+
+
+__all__ = ["Serializer", "Deserializer", "dumps", "loads"]
+
+
+def Serializer(*args, **kw):
+ pickler = pickle.Pickler(*args, **kw)
+
+ def persistent_id(obj):
+ # print("serializing:", repr(obj))
+ if isinstance(obj, Mapper) and not obj.non_primary:
+ id_ = "mapper:" + b64encode(pickle.dumps(obj.class_))
+ elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
+ id_ = (
+ "mapperprop:"
+ + b64encode(pickle.dumps(obj.parent.class_))
+ + ":"
+ + obj.key
+ )
+ elif isinstance(obj, Table):
+ if "parententity" in obj._annotations:
+ id_ = "mapper_selectable:" + b64encode(
+ pickle.dumps(obj._annotations["parententity"].class_)
+ )
+ else:
+ id_ = "table:" + text_type(obj.key)
+ elif isinstance(obj, Column) and isinstance(obj.table, Table):
+ id_ = (
+ "column:" + text_type(obj.table.key) + ":" + text_type(obj.key)
+ )
+ elif isinstance(obj, Session):
+ id_ = "session:"
+ elif isinstance(obj, Engine):
+ id_ = "engine:"
+ else:
+ return None
+ return id_
+
+ pickler.persistent_id = persistent_id
+ return pickler
+
+
+our_ids = re.compile(
+ r"(mapperprop|mapper|mapper_selectable|table|column|"
+ r"session|attribute|engine):(.*)"
+)
+
+
+def Deserializer(file, metadata=None, scoped_session=None, engine=None):
+ unpickler = pickle.Unpickler(file)
+
+ def get_engine():
+ if engine:
+ return engine
+ elif scoped_session and scoped_session().bind:
+ return scoped_session().bind
+ elif metadata and metadata.bind:
+ return metadata.bind
+ else:
+ return None
+
+ def persistent_load(id_):
+ m = our_ids.match(text_type(id_))
+ if not m:
+ return None
+ else:
+ type_, args = m.group(1, 2)
+ if type_ == "attribute":
+ key, clsarg = args.split(":")
+ cls = pickle.loads(b64decode(clsarg))
+ return getattr(cls, key)
+ elif type_ == "mapper":
+ cls = pickle.loads(b64decode(args))
+ return class_mapper(cls)
+ elif type_ == "mapper_selectable":
+ cls = pickle.loads(b64decode(args))
+ return class_mapper(cls).__clause_element__()
+ elif type_ == "mapperprop":
+ mapper, keyname = args.split(":")
+ cls = pickle.loads(b64decode(mapper))
+ return class_mapper(cls).attrs[keyname]
+ elif type_ == "table":
+ return metadata.tables[args]
+ elif type_ == "column":
+ table, colname = args.split(":")
+ return metadata.tables[table].c[colname]
+ elif type_ == "session":
+ return scoped_session()
+ elif type_ == "engine":
+ return get_engine()
+ else:
+ raise Exception("Unknown token: %s" % type_)
+
+ unpickler.persistent_load = persistent_load
+ return unpickler
+
+
+def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL):
+ buf = byte_buffer()
+ pickler = Serializer(buf, protocol)
+ pickler.dump(obj)
+ return buf.getvalue()
+
+
+def loads(data, metadata=None, scoped_session=None, engine=None):
+ buf = byte_buffer(data)
+ unpickler = Deserializer(buf, metadata, scoped_session, engine)
+ return unpickler.load()
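+
+
+# A minimal Core round trip (a sketch; assumes a Table named "users" is
+# already present in "metadata")::
+#
+#     expr = users.select().where(users.c.name == "ed")
+#     expr2 = loads(dumps(expr), metadata)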