Diffstat (limited to 'lib/sqlalchemy/ext')
29 files changed, 13167 insertions, 0 deletions
diff --git a/lib/sqlalchemy/ext/__init__.py b/lib/sqlalchemy/ext/__init__.py new file mode 100644 index 0000000..62bbbf3 --- /dev/null +++ b/lib/sqlalchemy/ext/__init__.py @@ -0,0 +1,11 @@ +# ext/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .. import util as _sa_util + + +_sa_util.preloaded.import_prefix("sqlalchemy.ext") diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py new file mode 100644 index 0000000..fbf377a --- /dev/null +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -0,0 +1,1627 @@ +# ext/associationproxy.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Contain the ``AssociationProxy`` class. + +The ``AssociationProxy`` is a Python property object which provides +transparent proxied access to the endpoint of an association object. + +See the example ``examples/association/proxied_association.py``. + +""" +import operator + +from .. import exc +from .. import inspect +from .. import orm +from .. import util +from ..orm import collections +from ..orm import interfaces +from ..sql import or_ +from ..sql.operators import ColumnOperators + + +def association_proxy(target_collection, attr, **kw): + r"""Return a Python property implementing a view of a target + attribute which references an attribute on members of the + target. + + The returned value is an instance of :class:`.AssociationProxy`. + + Implements a Python property representing a relationship as a collection + of simpler values, or a scalar value. The proxied property will mimic + the collection type of the target (list, dict or set), or, in the case of + a one to one relationship, a simple scalar value. + + :param target_collection: Name of the attribute we'll proxy to. + This attribute is typically mapped by + :func:`~sqlalchemy.orm.relationship` to link to a target collection, but + can also be a many-to-one or non-scalar relationship. + + :param attr: Attribute on the associated instance or instances we'll + proxy for. + + For example, given a target collection of [obj1, obj2], a list created + by this proxy property would look like [getattr(obj1, *attr*), + getattr(obj2, *attr*)] + + If the relationship is one-to-one or otherwise uselist=False, then + simply: getattr(obj, *attr*) + + :param creator: optional. + + When new items are added to this proxied collection, new instances of + the class collected by the target collection will be created. For list + and set collections, the target class constructor will be called with + the 'value' for the new instance. For dict types, two arguments are + passed: key and value. + + If you want to construct instances differently, supply a *creator* + function that takes arguments as above and returns instances. + + For scalar relationships, creator() will be called if the target is None. + If the target is present, set operations are proxied to setattr() on the + associated object. + + If you have an associated object with multiple attributes, you may set + up multiple association proxies mapping to different attributes. 
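A minimal sketch of the ``creator`` behavior described above (the
``User``/``Keyword`` mapping is invented for illustration and is not
part of this changeset)::

    from sqlalchemy import Column, ForeignKey, Integer, String
    from sqlalchemy.ext.associationproxy import association_proxy
    from sqlalchemy.orm import declarative_base, relationship

    Base = declarative_base()

    class Keyword(Base):
        __tablename__ = "keyword"
        id = Column(Integer, primary_key=True)
        user_id = Column(ForeignKey("user.id"))
        word = Column(String(50))

    class User(Base):
        __tablename__ = "user"
        id = Column(Integer, primary_key=True)
        kws = relationship(Keyword)

        # appending a plain string builds the intermediary Keyword
        # through the creator callable
        keywords = association_proxy(
            "kws", "word", creator=lambda w: Keyword(word=w)
        )

    u = User()
    u.keywords.append("editing")  # same as u.kws.append(Keyword(word="editing"))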
See + the unit tests for examples, and for examples of how creator() functions + can be used to construct the scalar relationship on-demand in this + situation. + + :param \*\*kw: Passes along any other keyword arguments to + :class:`.AssociationProxy`. + + """ + return AssociationProxy(target_collection, attr, **kw) + + +ASSOCIATION_PROXY = util.symbol("ASSOCIATION_PROXY") +"""Symbol indicating an :class:`.InspectionAttr` that's + of type :class:`.AssociationProxy`. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + +""" + + +class AssociationProxy(interfaces.InspectionAttrInfo): + """A descriptor that presents a read/write view of an object attribute.""" + + is_attribute = True + extension_type = ASSOCIATION_PROXY + + def __init__( + self, + target_collection, + attr, + creator=None, + getset_factory=None, + proxy_factory=None, + proxy_bulk_set=None, + info=None, + cascade_scalar_deletes=False, + ): + """Construct a new :class:`.AssociationProxy`. + + The :func:`.association_proxy` function is provided as the usual + entrypoint here, though :class:`.AssociationProxy` can be instantiated + and/or subclassed directly. + + :param target_collection: Name of the collection we'll proxy to, + usually created with :func:`_orm.relationship`. + + :param attr: Attribute on the collected instances we'll proxy + for. For example, given a target collection of [obj1, obj2], a + list created by this proxy property would look like + [getattr(obj1, attr), getattr(obj2, attr)] + + :param creator: Optional. When new items are added to this proxied + collection, new instances of the class collected by the target + collection will be created. For list and set collections, the + target class constructor will be called with the 'value' for the + new instance. For dict types, two arguments are passed: + key and value. + + If you want to construct instances differently, supply a 'creator' + function that takes arguments as above and returns instances. + + :param cascade_scalar_deletes: when True, indicates that setting + the proxied value to ``None``, or deleting it via ``del``, should + also remove the source object. Only applies to scalar attributes. + Normally, removing the proxied target will not remove the proxy + source, as this object may have other state that is still to be + kept. + + .. versionadded:: 1.3 + + .. seealso:: + + :ref:`cascade_scalar_deletes` - complete usage example + + :param getset_factory: Optional. Proxied attribute access is + automatically handled by routines that get and set values based on + the `attr` argument for this proxy. + + If you would like to customize this behavior, you may supply a + `getset_factory` callable that produces a tuple of `getter` and + `setter` functions. The factory is called with two arguments, the + abstract type of the underlying collection and this proxy instance. + + :param proxy_factory: Optional. The type of collection to emulate is + determined by sniffing the target collection. If your collection + type can't be determined by duck typing or you'd like to use a + different collection implementation, you may supply a factory + function to produce those collections. Only applicable to + non-scalar relationships. + + :param proxy_bulk_set: Optional, use with proxy_factory. See + the _set() method for details. + + :param info: optional, will be assigned to + :attr:`.AssociationProxy.info` if present. + + .. 
versionadded:: 1.0.9 + + """ + self.target_collection = target_collection + self.value_attr = attr + self.creator = creator + self.getset_factory = getset_factory + self.proxy_factory = proxy_factory + self.proxy_bulk_set = proxy_bulk_set + self.cascade_scalar_deletes = cascade_scalar_deletes + + self.key = "_%s_%s_%s" % ( + type(self).__name__, + target_collection, + id(self), + ) + if info: + self.info = info + + def __get__(self, obj, class_): + if class_ is None: + return self + inst = self._as_instance(class_, obj) + if inst: + return inst.get(obj) + + # obj has to be None here + # assert obj is None + + return self + + def __set__(self, obj, values): + class_ = type(obj) + return self._as_instance(class_, obj).set(obj, values) + + def __delete__(self, obj): + class_ = type(obj) + return self._as_instance(class_, obj).delete(obj) + + def for_class(self, class_, obj=None): + r"""Return the internal state local to a specific mapped class. + + E.g., given a class ``User``:: + + class User(Base): + # ... + + keywords = association_proxy('kws', 'keyword') + + If we access this :class:`.AssociationProxy` from + :attr:`_orm.Mapper.all_orm_descriptors`, and we want to view the + target class for this proxy as mapped by ``User``:: + + inspect(User).all_orm_descriptors["keywords"].for_class(User).target_class + + This returns an instance of :class:`.AssociationProxyInstance` that + is specific to the ``User`` class. The :class:`.AssociationProxy` + object remains agnostic of its parent class. + + :param class\_: the class that we are returning state for. + + :param obj: optional, an instance of the class that is required + if the attribute refers to a polymorphic target, e.g. where we have + to look at the type of the actual destination object to get the + complete path. + + .. versionadded:: 1.3 - :class:`.AssociationProxy` no longer stores + any state specific to a particular parent class; the state is now + stored in per-class :class:`.AssociationProxyInstance` objects. + + + """ + return self._as_instance(class_, obj) + + def _as_instance(self, class_, obj): + try: + inst = class_.__dict__[self.key + "_inst"] + except KeyError: + inst = None + + # avoid exception context + if inst is None: + owner = self._calc_owner(class_) + if owner is not None: + inst = AssociationProxyInstance.for_proxy(self, owner, obj) + setattr(class_, self.key + "_inst", inst) + else: + inst = None + + if inst is not None and not inst._is_canonical: + # the AssociationProxyInstance can't be generalized + # since the proxied attribute is not on the targeted + # class, only on subclasses of it, which might be + # different. only return for the specific + # object's current value + return inst._non_canonical_get_for_object(obj) + else: + return inst + + def _calc_owner(self, target_cls): + # we might be getting invoked for a subclass + # that is not mapped yet, in some declarative situations. + # save until we are mapped + try: + insp = inspect(target_cls) + except exc.NoInspectionAvailable: + # can't find a mapper, don't set owner. if we are a not-yet-mapped + # subclass, we can also scan through __mro__ to find a mapped + # class, but instead just wait for us to be called again against a + # mapped class normally. 
+ return None + else: + return insp.mapper.class_manager.class_ + + def _default_getset(self, collection_class): + attr = self.value_attr + _getter = operator.attrgetter(attr) + + def getter(target): + return _getter(target) if target is not None else None + + if collection_class is dict: + + def setter(o, k, v): + setattr(o, attr, v) + + else: + + def setter(o, v): + setattr(o, attr, v) + + return getter, setter + + def __repr__(self): + return "AssociationProxy(%r, %r)" % ( + self.target_collection, + self.value_attr, + ) + + +class AssociationProxyInstance(object): + """A per-class object that serves class- and object-specific results. + + This is used by :class:`.AssociationProxy` when it is invoked + in terms of a specific class or instance of a class, i.e. when it is + used as a regular Python descriptor. + + When referring to the :class:`.AssociationProxy` as a normal Python + descriptor, the :class:`.AssociationProxyInstance` is the object that + actually serves the information. Under normal circumstances, its presence + is transparent:: + + >>> User.keywords.scalar + False + + In the special case that the :class:`.AssociationProxy` object is being + accessed directly, in order to get an explicit handle to the + :class:`.AssociationProxyInstance`, use the + :meth:`.AssociationProxy.for_class` method:: + + proxy_state = inspect(User).all_orm_descriptors["keywords"].for_class(User) + + # view if proxy object is scalar or not + >>> proxy_state.scalar + False + + .. versionadded:: 1.3 + + """ # noqa + + def __init__(self, parent, owning_class, target_class, value_attr): + self.parent = parent + self.key = parent.key + self.owning_class = owning_class + self.target_collection = parent.target_collection + self.collection_class = None + self.target_class = target_class + self.value_attr = value_attr + + target_class = None + """The intermediary class handled by this + :class:`.AssociationProxyInstance`. + + Intercepted append/set/assignment events will result + in the generation of new instances of this class. + + """ + + @classmethod + def for_proxy(cls, parent, owning_class, parent_instance): + target_collection = parent.target_collection + value_attr = parent.value_attr + prop = orm.class_mapper(owning_class).get_property(target_collection) + + # this was never asserted before but this should be made clear. 
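# e.g. an association_proxy() whose target_collection names a plain
# column property would land here; only a relationship()-mapped
# attribute can serve as the proxy's intermediary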
+ if not isinstance(prop, orm.RelationshipProperty): + util.raise_( + NotImplementedError( + "association proxy to a non-relationship " + "intermediary is not supported" + ), + replace_context=None, + ) + + target_class = prop.mapper.class_ + + try: + target_assoc = cls._cls_unwrap_target_assoc_proxy( + target_class, value_attr + ) + except AttributeError: + # the proxied attribute doesn't exist on the target class; + # return an "ambiguous" instance that will work on a per-object + # basis + return AmbiguousAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + except Exception as err: + util.raise_( + exc.InvalidRequestError( + "Association proxy received an unexpected error when " + "trying to retreive attribute " + '"%s.%s" from ' + 'class "%s": %s' + % ( + target_class.__name__, + parent.value_attr, + target_class.__name__, + err, + ) + ), + from_=err, + ) + else: + return cls._construct_for_assoc( + target_assoc, parent, owning_class, target_class, value_attr + ) + + @classmethod + def _construct_for_assoc( + cls, target_assoc, parent, owning_class, target_class, value_attr + ): + if target_assoc is not None: + return ObjectAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + + attr = getattr(target_class, value_attr) + if not hasattr(attr, "_is_internal_proxy"): + return AmbiguousAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + is_object = attr._impl_uses_objects + if is_object: + return ObjectAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + else: + return ColumnAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + + def _get_property(self): + return orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) + + @property + def _comparator(self): + return self._get_property().comparator + + def __clause_element__(self): + raise NotImplementedError( + "The association proxy can't be used as a plain column " + "expression; it only works inside of a comparison expression" + ) + + @classmethod + def _cls_unwrap_target_assoc_proxy(cls, target_class, value_attr): + attr = getattr(target_class, value_attr) + if isinstance(attr, (AssociationProxy, AssociationProxyInstance)): + return attr + return None + + @util.memoized_property + def _unwrap_target_assoc_proxy(self): + return self._cls_unwrap_target_assoc_proxy( + self.target_class, self.value_attr + ) + + @property + def remote_attr(self): + """The 'remote' class attribute referenced by this + :class:`.AssociationProxyInstance`. + + .. seealso:: + + :attr:`.AssociationProxyInstance.attr` + + :attr:`.AssociationProxyInstance.local_attr` + + """ + return getattr(self.target_class, self.value_attr) + + @property + def local_attr(self): + """The 'local' class attribute referenced by this + :class:`.AssociationProxyInstance`. + + .. seealso:: + + :attr:`.AssociationProxyInstance.attr` + + :attr:`.AssociationProxyInstance.remote_attr` + + """ + return getattr(self.owning_class, self.target_collection) + + @property + def attr(self): + """Return a tuple of ``(local_attr, remote_attr)``. + + This attribute was originally intended to facilitate using the + :meth:`_query.Query.join` method to join across the two relationships + at once, however this makes use of a deprecated calling style. 
+ + To use :meth:`_sql.select.join` or :meth:`_orm.Query.join` with + an association proxy, the current method is to make use of the + :attr:`.AssociationProxyInstance.local_attr` and + :attr:`.AssociationProxyInstance.remote_attr` attributes separately:: + + stmt = ( + select(Parent). + join(Parent.proxied.local_attr). + join(Parent.proxied.remote_attr) + ) + + A future release may seek to provide a more succinct join pattern + for association proxy attributes. + + .. seealso:: + + :attr:`.AssociationProxyInstance.local_attr` + + :attr:`.AssociationProxyInstance.remote_attr` + + """ + return (self.local_attr, self.remote_attr) + + @util.memoized_property + def scalar(self): + """Return ``True`` if this :class:`.AssociationProxyInstance` + proxies a scalar relationship on the local side.""" + + scalar = not self._get_property().uselist + if scalar: + self._initialize_scalar_accessors() + return scalar + + @util.memoized_property + def _value_is_scalar(self): + return ( + not self._get_property() + .mapper.get_property(self.value_attr) + .uselist + ) + + @property + def _target_is_object(self): + raise NotImplementedError() + + def _initialize_scalar_accessors(self): + if self.parent.getset_factory: + get, set_ = self.parent.getset_factory(None, self) + else: + get, set_ = self.parent._default_getset(None) + self._scalar_get, self._scalar_set = get, set_ + + def _default_getset(self, collection_class): + attr = self.value_attr + _getter = operator.attrgetter(attr) + + def getter(target): + return _getter(target) if target is not None else None + + if collection_class is dict: + + def setter(o, k, v): + return setattr(o, attr, v) + + else: + + def setter(o, v): + return setattr(o, attr, v) + + return getter, setter + + @property + def info(self): + return self.parent.info + + def get(self, obj): + if obj is None: + return self + + if self.scalar: + target = getattr(obj, self.target_collection) + return self._scalar_get(target) + else: + try: + # If the owning instance is reborn (orm session resurrect, + # etc.), refresh the proxy cache. 
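# the tuple cached under self.key is (id(obj), id(self), proxy);
# if either id no longer matches, the cached entry is stale and a
# fresh proxy collection is built below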
+ creator_id, self_id, proxy = getattr(obj, self.key) + except AttributeError: + pass + else: + if id(obj) == creator_id and id(self) == self_id: + assert self.collection_class is not None + return proxy + + self.collection_class, proxy = self._new( + _lazy_collection(obj, self.target_collection) + ) + setattr(obj, self.key, (id(obj), id(self), proxy)) + return proxy + + def set(self, obj, values): + if self.scalar: + creator = ( + self.parent.creator + if self.parent.creator + else self.target_class + ) + target = getattr(obj, self.target_collection) + if target is None: + if values is None: + return + setattr(obj, self.target_collection, creator(values)) + else: + self._scalar_set(target, values) + if values is None and self.parent.cascade_scalar_deletes: + setattr(obj, self.target_collection, None) + else: + proxy = self.get(obj) + assert self.collection_class is not None + if proxy is not values: + proxy._bulk_replace(self, values) + + def delete(self, obj): + if self.owning_class is None: + self._calc_owner(obj, None) + + if self.scalar: + target = getattr(obj, self.target_collection) + if target is not None: + delattr(target, self.value_attr) + delattr(obj, self.target_collection) + + def _new(self, lazy_collection): + creator = ( + self.parent.creator if self.parent.creator else self.target_class + ) + collection_class = util.duck_type_collection(lazy_collection()) + + if self.parent.proxy_factory: + return ( + collection_class, + self.parent.proxy_factory( + lazy_collection, creator, self.value_attr, self + ), + ) + + if self.parent.getset_factory: + getter, setter = self.parent.getset_factory(collection_class, self) + else: + getter, setter = self.parent._default_getset(collection_class) + + if collection_class is list: + return ( + collection_class, + _AssociationList( + lazy_collection, creator, getter, setter, self + ), + ) + elif collection_class is dict: + return ( + collection_class, + _AssociationDict( + lazy_collection, creator, getter, setter, self + ), + ) + elif collection_class is set: + return ( + collection_class, + _AssociationSet( + lazy_collection, creator, getter, setter, self + ), + ) + else: + raise exc.ArgumentError( + "could not guess which interface to use for " + 'collection_class "%s" backing "%s"; specify a ' + "proxy_factory and proxy_bulk_set manually" + % (self.collection_class.__name__, self.target_collection) + ) + + def _set(self, proxy, values): + if self.parent.proxy_bulk_set: + self.parent.proxy_bulk_set(proxy, values) + elif self.collection_class is list: + proxy.extend(values) + elif self.collection_class is dict: + proxy.update(values) + elif self.collection_class is set: + proxy.update(values) + else: + raise exc.ArgumentError( + "no proxy_bulk_set supplied for custom " + "collection_class implementation" + ) + + def _inflate(self, proxy): + creator = ( + self.parent.creator and self.parent.creator or self.target_class + ) + + if self.parent.getset_factory: + getter, setter = self.parent.getset_factory( + self.collection_class, self + ) + else: + getter, setter = self.parent._default_getset(self.collection_class) + + proxy.creator = creator + proxy.getter = getter + proxy.setter = setter + + def _criterion_exists(self, criterion=None, **kwargs): + is_has = kwargs.pop("is_has", None) + + target_assoc = self._unwrap_target_assoc_proxy + if target_assoc is not None: + inner = target_assoc._criterion_exists( + criterion=criterion, **kwargs + ) + return self._comparator._criterion_exists(inner) + + if self._target_is_object: + prop = 
getattr(self.target_class, self.value_attr) + value_expr = prop._criterion_exists(criterion, **kwargs) + else: + if kwargs: + raise exc.ArgumentError( + "Can't apply keyword arguments to column-targeted " + "association proxy; use ==" + ) + elif is_has and criterion is not None: + raise exc.ArgumentError( + "Non-empty has() not allowed for " + "column-targeted association proxy; use ==" + ) + + value_expr = criterion + + return self._comparator._criterion_exists(value_expr) + + def any(self, criterion=None, **kwargs): + """Produce a proxied 'any' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.RelationshipProperty.Comparator.any` + and/or :meth:`.RelationshipProperty.Comparator.has` + operators of the underlying proxied attributes. + + """ + if self._unwrap_target_assoc_proxy is None and ( + self.scalar + and (not self._target_is_object or self._value_is_scalar) + ): + raise exc.InvalidRequestError( + "'any()' not implemented for scalar " "attributes. Use has()." + ) + return self._criterion_exists( + criterion=criterion, is_has=False, **kwargs + ) + + def has(self, criterion=None, **kwargs): + """Produce a proxied 'has' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.RelationshipProperty.Comparator.any` + and/or :meth:`.RelationshipProperty.Comparator.has` + operators of the underlying proxied attributes. + + """ + if self._unwrap_target_assoc_proxy is None and ( + not self.scalar + or (self._target_is_object and not self._value_is_scalar) + ): + raise exc.InvalidRequestError( + "'has()' not implemented for collections. " "Use any()." + ) + return self._criterion_exists( + criterion=criterion, is_has=True, **kwargs + ) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self.parent) + + +class AmbiguousAssociationProxyInstance(AssociationProxyInstance): + """an :class:`.AssociationProxyInstance` where we cannot determine + the type of target object. + """ + + _is_canonical = False + + def _ambiguous(self): + raise AttributeError( + "Association proxy %s.%s refers to an attribute '%s' that is not " + "directly mapped on class %s; therefore this operation cannot " + "proceed since we don't know what type of object is referred " + "towards" + % ( + self.owning_class.__name__, + self.target_collection, + self.value_attr, + self.target_class, + ) + ) + + def get(self, obj): + if obj is None: + return self + else: + return super(AmbiguousAssociationProxyInstance, self).get(obj) + + def __eq__(self, obj): + self._ambiguous() + + def __ne__(self, obj): + self._ambiguous() + + def any(self, criterion=None, **kwargs): + self._ambiguous() + + def has(self, criterion=None, **kwargs): + self._ambiguous() + + @util.memoized_property + def _lookup_cache(self): + # mapping of <subclass>->AssociationProxyInstance. + # e.g. 
proxy is A-> A.b -> B -> B.b_attr, but B.b_attr doesn't exist; + # only B1(B) and B2(B) have "b_attr", keys in here would be B1, B2 + return {} + + def _non_canonical_get_for_object(self, parent_instance): + if parent_instance is not None: + actual_obj = getattr(parent_instance, self.target_collection) + if actual_obj is not None: + try: + insp = inspect(actual_obj) + except exc.NoInspectionAvailable: + pass + else: + mapper = insp.mapper + instance_class = mapper.class_ + if instance_class not in self._lookup_cache: + self._populate_cache(instance_class, mapper) + + try: + return self._lookup_cache[instance_class] + except KeyError: + pass + + # no object or ambiguous object given, so return "self", which + # is a proxy with generally only instance-level functionality + return self + + def _populate_cache(self, instance_class, mapper): + prop = orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) + + if mapper.isa(prop.mapper): + target_class = instance_class + try: + target_assoc = self._cls_unwrap_target_assoc_proxy( + target_class, self.value_attr + ) + except AttributeError: + pass + else: + self._lookup_cache[instance_class] = self._construct_for_assoc( + target_assoc, + self.parent, + self.owning_class, + target_class, + self.value_attr, + ) + + +class ObjectAssociationProxyInstance(AssociationProxyInstance): + """an :class:`.AssociationProxyInstance` that has an object as a target.""" + + _target_is_object = True + _is_canonical = True + + def contains(self, obj): + """Produce a proxied 'contains' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.RelationshipProperty.Comparator.any`, + :meth:`.RelationshipProperty.Comparator.has`, + and/or :meth:`.RelationshipProperty.Comparator.contains` + operators of the underlying proxied attributes. + """ + + target_assoc = self._unwrap_target_assoc_proxy + if target_assoc is not None: + return self._comparator._criterion_exists( + target_assoc.contains(obj) + if not target_assoc.scalar + else target_assoc == obj + ) + elif ( + self._target_is_object + and self.scalar + and not self._value_is_scalar + ): + return self._comparator.has( + getattr(self.target_class, self.value_attr).contains(obj) + ) + elif self._target_is_object and self.scalar and self._value_is_scalar: + raise exc.InvalidRequestError( + "contains() doesn't apply to a scalar object endpoint; use ==" + ) + else: + + return self._comparator._criterion_exists(**{self.value_attr: obj}) + + def __eq__(self, obj): + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + if obj is None: + return or_( + self._comparator.has(**{self.value_attr: obj}), + self._comparator == None, + ) + else: + return self._comparator.has(**{self.value_attr: obj}) + + def __ne__(self, obj): + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + return self._comparator.has( + getattr(self.target_class, self.value_attr) != obj + ) + + +class ColumnAssociationProxyInstance( + ColumnOperators, AssociationProxyInstance +): + """an :class:`.AssociationProxyInstance` that has a database column as a + target. 
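For example (a sketch; ``User.keywords`` stands in for a proxy whose
endpoint is a string column), comparisons render as an EXISTS against
the related column::

    # matches any User with at least one related row whose
    # proxied column equals "orm"
    stmt = select(User).where(User.keywords == "orm")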
+ """ + + _target_is_object = False + _is_canonical = True + + def __eq__(self, other): + # special case "is None" to check for no related row as well + expr = self._criterion_exists( + self.remote_attr.operate(operator.eq, other) + ) + if other is None: + return or_(expr, self._comparator == None) + else: + return expr + + def operate(self, op, *other, **kwargs): + return self._criterion_exists( + self.remote_attr.operate(op, *other, **kwargs) + ) + + +class _lazy_collection(object): + def __init__(self, obj, target): + self.parent = obj + self.target = target + + def __call__(self): + return getattr(self.parent, self.target) + + def __getstate__(self): + return {"obj": self.parent, "target": self.target} + + def __setstate__(self, state): + self.parent = state["obj"] + self.target = state["target"] + + +class _AssociationCollection(object): + def __init__(self, lazy_collection, creator, getter, setter, parent): + """Constructs an _AssociationCollection. + + This will always be a subclass of either _AssociationList, + _AssociationSet, or _AssociationDict. + + lazy_collection + A callable returning a list-based collection of entities (usually an + object attribute managed by a SQLAlchemy relationship()) + + creator + A function that creates new target entities. Given one parameter: + value. This assertion is assumed:: + + obj = creator(somevalue) + assert getter(obj) == somevalue + + getter + A function. Given an associated object, return the 'value'. + + setter + A function. Given an associated object and a value, store that + value on the object. + + """ + self.lazy_collection = lazy_collection + self.creator = creator + self.getter = getter + self.setter = setter + self.parent = parent + + col = property(lambda self: self.lazy_collection()) + + def __len__(self): + return len(self.col) + + def __bool__(self): + return bool(self.col) + + __nonzero__ = __bool__ + + def __getstate__(self): + return {"parent": self.parent, "lazy_collection": self.lazy_collection} + + def __setstate__(self, state): + self.parent = state["parent"] + self.lazy_collection = state["lazy_collection"] + self.parent._inflate(self) + + def _bulk_replace(self, assoc_proxy, values): + self.clear() + assoc_proxy._set(self, values) + + +class _AssociationList(_AssociationCollection): + """Generic, converting, list-to-list proxy.""" + + def _create(self, value): + return self.creator(value) + + def _get(self, object_): + return self.getter(object_) + + def _set(self, object_, value): + return self.setter(object_, value) + + def __getitem__(self, index): + if not isinstance(index, slice): + return self._get(self.col[index]) + else: + return [self._get(member) for member in self.col[index]] + + def __setitem__(self, index, value): + if not isinstance(index, slice): + self._set(self.col[index], value) + else: + if index.stop is None: + stop = len(self) + elif index.stop < 0: + stop = len(self) + index.stop + else: + stop = index.stop + step = index.step or 1 + + start = index.start or 0 + rng = list(range(index.start or 0, stop, step)) + if step == 1: + for i in rng: + del self[start] + i = start + for item in value: + self.insert(i, item) + i += 1 + else: + if len(value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" % (len(value), len(rng)) + ) + for i, item in zip(rng, value): + self._set(self.col[i], item) + + def __delitem__(self, index): + del self.col[index] + + def __contains__(self, value): + for member in self.col: + # testlib.pragma exempt:__eq__ + if 
self._get(member) == value: + return True + return False + + def __getslice__(self, start, end): + return [self._get(member) for member in self.col[start:end]] + + def __setslice__(self, start, end, values): + members = [self._create(v) for v in values] + self.col[start:end] = members + + def __delslice__(self, start, end): + del self.col[start:end] + + def __iter__(self): + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or + just use the underlying collection directly from its property + on the parent. + """ + + for member in self.col: + yield self._get(member) + return + + def append(self, value): + col = self.col + item = self._create(value) + col.append(item) + + def count(self, value): + return sum( + [ + 1 + for _ in util.itertools_filter( + lambda v: v == value, iter(self) + ) + ] + ) + + def extend(self, values): + for v in values: + self.append(v) + + def insert(self, index, value): + self.col[index:index] = [self._create(value)] + + def pop(self, index=-1): + return self.getter(self.col.pop(index)) + + def remove(self, value): + for i, val in enumerate(self): + if val == value: + del self.col[i] + return + raise ValueError("value not in list") + + def reverse(self): + """Not supported, use reversed(mylist)""" + + raise NotImplementedError + + def sort(self): + """Not supported, use sorted(mylist)""" + + raise NotImplementedError + + def clear(self): + del self.col[0 : len(self.col)] + + def __eq__(self, other): + return list(self) == other + + def __ne__(self, other): + return list(self) != other + + def __lt__(self, other): + return list(self) < other + + def __le__(self, other): + return list(self) <= other + + def __gt__(self, other): + return list(self) > other + + def __ge__(self, other): + return list(self) >= other + + def __cmp__(self, other): + return util.cmp(list(self), other) + + def __add__(self, iterable): + try: + other = list(iterable) + except TypeError: + return NotImplemented + return list(self) + other + + def __radd__(self, iterable): + try: + other = list(iterable) + except TypeError: + return NotImplemented + return other + list(self) + + def __mul__(self, n): + if not isinstance(n, int): + return NotImplemented + return list(self) * n + + __rmul__ = __mul__ + + def __iadd__(self, iterable): + self.extend(iterable) + return self + + def __imul__(self, n): + # unlike a regular list *=, proxied __imul__ will generate unique + # backing objects for each copy. *= on proxied lists is a bit of + # a stretch anyhow, and this interpretation of the __imul__ contract + # is more plausibly useful than copying the backing objects. 
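# e.g. plist *= 2 re-runs creator() to build new intermediary
# objects for the appended values, rather than re-using the
# existing backing objects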
+ if not isinstance(n, int): + return NotImplemented + if n == 0: + self.clear() + elif n > 1: + self.extend(list(self) * (n - 1)) + return self + + def index(self, item, *args): + return list(self).index(item, *args) + + def copy(self): + return list(self) + + def __repr__(self): + return repr(list(self)) + + def __hash__(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) + + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + + +_NotProvided = util.symbol("_NotProvided") + + +class _AssociationDict(_AssociationCollection): + """Generic, converting, dict-to-dict proxy.""" + + def _create(self, key, value): + return self.creator(key, value) + + def _get(self, object_): + return self.getter(object_) + + def _set(self, object_, key, value): + return self.setter(object_, key, value) + + def __getitem__(self, key): + return self._get(self.col[key]) + + def __setitem__(self, key, value): + if key in self.col: + self._set(self.col[key], key, value) + else: + self.col[key] = self._create(key, value) + + def __delitem__(self, key): + del self.col[key] + + def __contains__(self, key): + # testlib.pragma exempt:__hash__ + return key in self.col + + def has_key(self, key): + # testlib.pragma exempt:__hash__ + return key in self.col + + def __iter__(self): + return iter(self.col.keys()) + + def clear(self): + self.col.clear() + + def __eq__(self, other): + return dict(self) == other + + def __ne__(self, other): + return dict(self) != other + + def __lt__(self, other): + return dict(self) < other + + def __le__(self, other): + return dict(self) <= other + + def __gt__(self, other): + return dict(self) > other + + def __ge__(self, other): + return dict(self) >= other + + def __cmp__(self, other): + return util.cmp(dict(self), other) + + def __repr__(self): + return repr(dict(self.items())) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key, default=None): + if key not in self.col: + self.col[key] = self._create(key, default) + return default + else: + return self[key] + + def keys(self): + return self.col.keys() + + if util.py2k: + + def iteritems(self): + return ((key, self._get(self.col[key])) for key in self.col) + + def itervalues(self): + return (self._get(self.col[key]) for key in self.col) + + def iterkeys(self): + return self.col.iterkeys() + + def values(self): + return [self._get(member) for member in self.col.values()] + + def items(self): + return [(k, self._get(self.col[k])) for k in self] + + else: + + def items(self): + return ((key, self._get(self.col[key])) for key in self.col) + + def values(self): + return (self._get(self.col[key]) for key in self.col) + + def pop(self, key, default=_NotProvided): + if default is _NotProvided: + member = self.col.pop(key) + else: + member = self.col.pop(key, default) + return self._get(member) + + def popitem(self): + item = self.col.popitem() + return (item[0], self._get(item[1])) + + def update(self, *a, **kw): + if len(a) > 1: + raise TypeError( + "update expected at most 1 arguments, got %i" % len(a) + ) + elif len(a) == 1: + seq_or_map = a[0] + # discern dict from sequence - took the advice from + # https://www.voidspace.org.uk/python/articles/duck_typing.shtml + # still not perfect :( + if hasattr(seq_or_map, "keys"): + for item in seq_or_map: + self[item] = 
seq_or_map[item] + else: + try: + for k, v in seq_or_map: + self[k] = v + except ValueError as err: + util.raise_( + ValueError( + "dictionary update sequence " + "requires 2-element tuples" + ), + replace_context=err, + ) + + for key, value in kw: + self[key] = value + + def _bulk_replace(self, assoc_proxy, values): + existing = set(self) + constants = existing.intersection(values or ()) + additions = set(values or ()).difference(constants) + removals = existing.difference(constants) + + for key, member in values.items() or (): + if key in additions: + self[key] = member + elif key in constants: + self[key] = member + + for key in removals: + del self[key] + + def copy(self): + return dict(self.items()) + + def __hash__(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) + + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(dict, func_name) + ): + func.__doc__ = getattr(dict, func_name).__doc__ + del func_name, func + + +class _AssociationSet(_AssociationCollection): + """Generic, converting, set-to-set proxy.""" + + def _create(self, value): + return self.creator(value) + + def _get(self, object_): + return self.getter(object_) + + def __len__(self): + return len(self.col) + + def __bool__(self): + if self.col: + return True + else: + return False + + __nonzero__ = __bool__ + + def __contains__(self, value): + for member in self.col: + # testlib.pragma exempt:__eq__ + if self._get(member) == value: + return True + return False + + def __iter__(self): + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or just use + the underlying collection directly from its property on the parent. + + """ + for member in self.col: + yield self._get(member) + return + + def add(self, value): + if value not in self: + self.col.add(self._create(value)) + + # for discard and remove, choosing a more expensive check strategy rather + # than call self.creator() + def discard(self, value): + for member in self.col: + if self._get(member) == value: + self.col.discard(member) + break + + def remove(self, value): + for member in self.col: + if self._get(member) == value: + self.col.discard(member) + return + raise KeyError(value) + + def pop(self): + if not self.col: + raise KeyError("pop from an empty set") + member = self.col.pop() + return self._get(member) + + def update(self, other): + for value in other: + self.add(value) + + def _bulk_replace(self, assoc_proxy, values): + existing = set(self) + constants = existing.intersection(values or ()) + additions = set(values or ()).difference(constants) + removals = existing.difference(constants) + + appender = self.add + remover = self.remove + + for member in values or (): + if member in additions: + appender(member) + elif member in constants: + appender(member) + + for member in removals: + remover(member) + + def __ior__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + for value in other: + self.add(value) + return self + + def _set(self): + return set(iter(self)) + + def union(self, other): + return set(self).union(other) + + __or__ = union + + def difference(self, other): + return set(self).difference(other) + + __sub__ = difference + + def difference_update(self, other): + for value in other: + self.discard(value) + + def __isub__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + for value in other: + 
self.discard(value) + return self + + def intersection(self, other): + return set(self).intersection(other) + + __and__ = intersection + + def intersection_update(self, other): + want, have = self.intersection(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __iand__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + want, have = self.intersection(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + return self + + def symmetric_difference(self, other): + return set(self).symmetric_difference(other) + + __xor__ = symmetric_difference + + def symmetric_difference_update(self, other): + want, have = self.symmetric_difference(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __ixor__(self, other): + if not collections._set_binops_check_strict(self, other): + return NotImplemented + want, have = self.symmetric_difference(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + return self + + def issubset(self, other): + return set(self).issubset(other) + + def issuperset(self, other): + return set(self).issuperset(other) + + def clear(self): + self.col.clear() + + def copy(self): + return set(self) + + def __eq__(self, other): + return set(self) == other + + def __ne__(self, other): + return set(self) != other + + def __lt__(self, other): + return set(self) < other + + def __le__(self, other): + return set(self) <= other + + def __gt__(self, other): + return set(self) > other + + def __ge__(self, other): + return set(self) >= other + + def __repr__(self): + return repr(set(self)) + + def __hash__(self): + raise TypeError("%s objects are unhashable" % type(self).__name__) + + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(set, func_name) + ): + func.__doc__ = getattr(set, func_name).__doc__ + del func_name, func diff --git a/lib/sqlalchemy/ext/asyncio/__init__.py b/lib/sqlalchemy/ext/asyncio/__init__.py new file mode 100644 index 0000000..15b2cb0 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/__init__.py @@ -0,0 +1,22 @@ +# ext/asyncio/__init__.py +# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .engine import async_engine_from_config +from .engine import AsyncConnection +from .engine import AsyncEngine +from .engine import AsyncTransaction +from .engine import create_async_engine +from .events import AsyncConnectionEvents +from .events import AsyncSessionEvents +from .result import AsyncMappingResult +from .result import AsyncResult +from .result import AsyncScalarResult +from .scoping import async_scoped_session +from .session import async_object_session +from .session import async_session +from .session import AsyncSession +from .session import AsyncSessionTransaction diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py new file mode 100644 index 0000000..3f77f55 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -0,0 +1,89 @@ +import abc +import 
functools +import weakref + +from . import exc as async_exc + + +class ReversibleProxy: + # weakref.ref(async proxy object) -> weakref.ref(sync proxied object) + _proxy_objects = {} + __slots__ = ("__weakref__",) + + def _assign_proxied(self, target): + if target is not None: + target_ref = weakref.ref(target, ReversibleProxy._target_gced) + proxy_ref = weakref.ref( + self, + functools.partial(ReversibleProxy._target_gced, target_ref), + ) + ReversibleProxy._proxy_objects[target_ref] = proxy_ref + + return target + + @classmethod + def _target_gced(cls, ref, proxy_ref=None): + cls._proxy_objects.pop(ref, None) + + @classmethod + def _regenerate_proxy_for_target(cls, target): + raise NotImplementedError() + + @classmethod + def _retrieve_proxy_for_target(cls, target, regenerate=True): + try: + proxy_ref = cls._proxy_objects[weakref.ref(target)] + except KeyError: + pass + else: + proxy = proxy_ref() + if proxy is not None: + return proxy + + if regenerate: + return cls._regenerate_proxy_for_target(target) + else: + return None + + +class StartableContext(abc.ABC): + __slots__ = () + + @abc.abstractmethod + async def start(self, is_ctxmanager=False): + pass + + def __await__(self): + return self.start().__await__() + + async def __aenter__(self): + return await self.start(is_ctxmanager=True) + + @abc.abstractmethod + async def __aexit__(self, type_, value, traceback): + pass + + def _raise_for_not_started(self): + raise async_exc.AsyncContextNotStarted( + "%s context has not been started and object has not been awaited." + % (self.__class__.__name__) + ) + + +class ProxyComparable(ReversibleProxy): + __slots__ = () + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self._proxied == other._proxied + ) + + def __ne__(self, other): + return ( + not isinstance(other, self.__class__) + or self._proxied != other._proxied + ) diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py new file mode 100644 index 0000000..4fbe4f7 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -0,0 +1,828 @@ +# ext/asyncio/engine.py +# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +import asyncio + +from . import exc as async_exc +from .base import ProxyComparable +from .base import StartableContext +from .result import _ensure_sync_result +from .result import AsyncResult +from ... import exc +from ... import inspection +from ... import util +from ...engine import create_engine as _create_engine +from ...engine.base import NestedTransaction +from ...future import Connection +from ...future import Engine +from ...util.concurrency import greenlet_spawn + + +def create_async_engine(*arg, **kw): + """Create a new async engine instance. + + Arguments passed to :func:`_asyncio.create_async_engine` are mostly + identical to those passed to the :func:`_sa.create_engine` function. + The specified dialect must be an asyncio-compatible dialect + such as :ref:`dialect-postgresql-asyncpg`. + + .. 
versionadded:: 1.4 + + """ + + if kw.get("server_side_cursors", False): + raise async_exc.AsyncMethodRequired( + "Can't set server_side_cursors for async engine globally; " + "use the connection.stream() method for an async " + "streaming result set" + ) + kw["future"] = True + sync_engine = _create_engine(*arg, **kw) + return AsyncEngine(sync_engine) + + +def async_engine_from_config(configuration, prefix="sqlalchemy.", **kwargs): + """Create a new AsyncEngine instance using a configuration dictionary. + + This function is analogous to the :func:`_sa.engine_from_config` function + in SQLAlchemy Core, except that the requested dialect must be an + asyncio-compatible dialect such as :ref:`dialect-postgresql-asyncpg`. + The argument signature of the function is identical to that + of :func:`_sa.engine_from_config`. + + .. versionadded:: 1.4.29 + + """ + options = { + key[len(prefix) :]: value + for key, value in configuration.items() + if key.startswith(prefix) + } + options["_coerce_config"] = True + options.update(kwargs) + url = options.pop("url") + return create_async_engine(url, **options) + + +class AsyncConnectable: + __slots__ = "_slots_dispatch", "__weakref__" + + +@util.create_proxy_methods( + Connection, + ":class:`_future.Connection`", + ":class:`_asyncio.AsyncConnection`", + classmethods=[], + methods=[], + attributes=[ + "closed", + "invalidated", + "dialect", + "default_isolation_level", + ], +) +class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable): + """An asyncio proxy for a :class:`_engine.Connection`. + + :class:`_asyncio.AsyncConnection` is acquired using the + :meth:`_asyncio.AsyncEngine.connect` + method of :class:`_asyncio.AsyncEngine`:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") + + async with engine.connect() as conn: + result = await conn.execute(select(table)) + + .. versionadded:: 1.4 + + """ # noqa + + # AsyncConnection is a thin proxy; no state should be added here + # that is not retrievable from the "sync" engine / connection, e.g. + # current transaction, info, etc. It should be possible to + # create a new AsyncConnection that matches this one given only the + # "sync" elements. + __slots__ = ( + "engine", + "sync_engine", + "sync_connection", + ) + + def __init__(self, async_engine, sync_connection=None): + self.engine = async_engine + self.sync_engine = async_engine.sync_engine + self.sync_connection = self._assign_proxied(sync_connection) + + sync_connection: Connection + """Reference to the sync-style :class:`_engine.Connection` this + :class:`_asyncio.AsyncConnection` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + """ + + sync_engine: Engine + """Reference to the sync-style :class:`_engine.Engine` this + :class:`_asyncio.AsyncConnection` is associated with via its underlying + :class:`_engine.Connection`. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + """ + + @classmethod + def _regenerate_proxy_for_target(cls, target): + return AsyncConnection( + AsyncEngine._retrieve_proxy_for_target(target.engine), target + ) + + async def start(self, is_ctxmanager=False): + """Start this :class:`_asyncio.AsyncConnection` object's context + outside of using a Python ``with:`` block. 
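E.g., as a sketch (``engine`` is assumed to be an existing
:class:`_asyncio.AsyncEngine`)::

    from sqlalchemy import text

    conn = await engine.connect()   # awaiting runs start() implicitly
    try:
        result = await conn.execute(text("SELECT 1"))
    finally:
        await conn.close()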
+ + """ + if self.sync_connection: + raise exc.InvalidRequestError("connection is already started") + self.sync_connection = self._assign_proxied( + await (greenlet_spawn(self.sync_engine.connect)) + ) + return self + + @property + def connection(self): + """Not implemented for async; call + :meth:`_asyncio.AsyncConnection.get_raw_connection`. + """ + raise exc.InvalidRequestError( + "AsyncConnection.connection accessor is not implemented as the " + "attribute may need to reconnect on an invalidated connection. " + "Use the get_raw_connection() method." + ) + + async def get_raw_connection(self): + """Return the pooled DBAPI-level connection in use by this + :class:`_asyncio.AsyncConnection`. + + This is a SQLAlchemy connection-pool proxied connection + which then has the attribute + :attr:`_pool._ConnectionFairy.driver_connection` that refers to the + actual driver connection. Its + :attr:`_pool._ConnectionFairy.dbapi_connection` refers instead + to an :class:`_engine.AdaptedConnection` instance that + adapts the driver connection to the DBAPI protocol. + + """ + conn = self._sync_connection() + + return await greenlet_spawn(getattr, conn, "connection") + + @property + def _proxied(self): + return self.sync_connection + + @property + def info(self): + """Return the :attr:`_engine.Connection.info` dictionary of the + underlying :class:`_engine.Connection`. + + This dictionary is freely writable for user-defined state to be + associated with the database connection. + + This attribute is only available if the :class:`.AsyncConnection` is + currently connected. If the :attr:`.AsyncConnection.closed` attribute + is ``True``, then accessing this attribute will raise + :class:`.ResourceClosedError`. + + .. versionadded:: 1.4.0b2 + + """ + return self.sync_connection.info + + def _sync_connection(self): + if not self.sync_connection: + self._raise_for_not_started() + return self.sync_connection + + def begin(self): + """Begin a transaction prior to autobegin occurring.""" + self._sync_connection() + return AsyncTransaction(self) + + def begin_nested(self): + """Begin a nested transaction and return a transaction handle.""" + self._sync_connection() + return AsyncTransaction(self, nested=True) + + async def invalidate(self, exception=None): + """Invalidate the underlying DBAPI connection associated with + this :class:`_engine.Connection`. + + See the method :meth:`_engine.Connection.invalidate` for full + detail on this method. + + """ + + conn = self._sync_connection() + return await greenlet_spawn(conn.invalidate, exception=exception) + + async def get_isolation_level(self): + conn = self._sync_connection() + return await greenlet_spawn(conn.get_isolation_level) + + async def set_isolation_level(self): + conn = self._sync_connection() + return await greenlet_spawn(conn.get_isolation_level) + + def in_transaction(self): + """Return True if a transaction is in progress. + + .. versionadded:: 1.4.0b2 + + """ + + conn = self._sync_connection() + + return conn.in_transaction() + + def in_nested_transaction(self): + """Return True if a transaction is in progress. + + .. versionadded:: 1.4.0b2 + + """ + conn = self._sync_connection() + + return conn.in_nested_transaction() + + def get_transaction(self): + """Return an :class:`.AsyncTransaction` representing the current + transaction, if any. 
+ + This makes use of the underlying synchronous connection's + :meth:`_engine.Connection.get_transaction` method to get the current + :class:`_engine.Transaction`, which is then proxied in a new + :class:`.AsyncTransaction` object. + + .. versionadded:: 1.4.0b2 + + """ + conn = self._sync_connection() + + trans = conn.get_transaction() + if trans is not None: + return AsyncTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_nested_transaction(self): + """Return an :class:`.AsyncTransaction` representing the current + nested (savepoint) transaction, if any. + + This makes use of the underlying synchronous connection's + :meth:`_engine.Connection.get_nested_transaction` method to get the + current :class:`_engine.Transaction`, which is then proxied in a new + :class:`.AsyncTransaction` object. + + .. versionadded:: 1.4.0b2 + + """ + conn = self._sync_connection() + + trans = conn.get_nested_transaction() + if trans is not None: + return AsyncTransaction._retrieve_proxy_for_target(trans) + else: + return None + + async def execution_options(self, **opt): + r"""Set non-SQL options for the connection which take effect + during execution. + + This returns this :class:`_asyncio.AsyncConnection` object with + the new options added. + + See :meth:`_future.Connection.execution_options` for full details + on this method. + + """ + + conn = self._sync_connection() + c2 = await greenlet_spawn(conn.execution_options, **opt) + assert c2 is conn + return self + + async def commit(self): + """Commit the transaction that is currently in progress. + + This method commits the current transaction if one has been started. + If no transaction was started, the method has no effect, assuming + the connection is in a non-invalidated state. + + A transaction is begun on a :class:`_future.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_future.Connection.begin` method is called. + + """ + conn = self._sync_connection() + await greenlet_spawn(conn.commit) + + async def rollback(self): + """Roll back the transaction that is currently in progress. + + This method rolls back the current transaction if one has been started. + If no transaction was started, the method has no effect. If a + transaction was started and the connection is in an invalidated state, + the transaction is cleared using this method. + + A transaction is begun on a :class:`_future.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_future.Connection.begin` method is called. + + + """ + conn = self._sync_connection() + await greenlet_spawn(conn.rollback) + + async def close(self): + """Close this :class:`_asyncio.AsyncConnection`. + + This has the effect of also rolling back the transaction if one + is in place. + + """ + conn = self._sync_connection() + await greenlet_spawn(conn.close) + + async def exec_driver_sql( + self, + statement, + parameters=None, + execution_options=util.EMPTY_DICT, + ): + r"""Executes a driver-level SQL string and return buffered + :class:`_engine.Result`. 
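E.g., a sketch; note that the parameter style shown is driver-specific::

    result = await conn.exec_driver_sql(
        "SELECT id, name FROM some_table WHERE id > ?", (10,)
    )
    rows = result.fetchall()   # buffered Result; fetching is synchronous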
+ + """ + + conn = self._sync_connection() + + result = await greenlet_spawn( + conn.exec_driver_sql, + statement, + parameters, + execution_options, + _require_await=True, + ) + + return await _ensure_sync_result(result, self.exec_driver_sql) + + async def stream( + self, + statement, + parameters=None, + execution_options=util.EMPTY_DICT, + ): + """Execute a statement and return a streaming + :class:`_asyncio.AsyncResult` object.""" + + conn = self._sync_connection() + + result = await greenlet_spawn( + conn._execute_20, + statement, + parameters, + util.EMPTY_DICT.merge_with( + execution_options, {"stream_results": True} + ), + _require_await=True, + ) + if not result.context._is_server_side: + # TODO: real exception here + assert False, "server side result expected" + return AsyncResult(result) + + async def execute( + self, + statement, + parameters=None, + execution_options=util.EMPTY_DICT, + ): + r"""Executes a SQL statement construct and return a buffered + :class:`_engine.Result`. + + :param object: The statement to be executed. This is always + an object that is in both the :class:`_expression.ClauseElement` and + :class:`_expression.Executable` hierarchies, including: + + * :class:`_expression.Select` + * :class:`_expression.Insert`, :class:`_expression.Update`, + :class:`_expression.Delete` + * :class:`_expression.TextClause` and + :class:`_expression.TextualSelect` + * :class:`_schema.DDL` and objects which inherit from + :class:`_schema.DDLElement` + + :param parameters: parameters which will be bound into the statement. + This may be either a dictionary of parameter names to values, + or a mutable sequence (e.g. a list) of dictionaries. When a + list of dictionaries is passed, the underlying statement execution + will make use of the DBAPI ``cursor.executemany()`` method. + When a single dictionary is passed, the DBAPI ``cursor.execute()`` + method will be used. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_future.Connection.execution_options`. + + :return: a :class:`_engine.Result` object. + + """ + conn = self._sync_connection() + + result = await greenlet_spawn( + conn._execute_20, + statement, + parameters, + execution_options, + _require_await=True, + ) + return await _ensure_sync_result(result, self.execute) + + async def scalar( + self, + statement, + parameters=None, + execution_options=util.EMPTY_DICT, + ): + r"""Executes a SQL statement construct and returns a scalar object. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalar` method after invoking the + :meth:`_future.Connection.execute` method. Parameters are equivalent. + + :return: a scalar Python value representing the first column of the + first row returned. + + """ + result = await self.execute(statement, parameters, execution_options) + return result.scalar() + + async def scalars( + self, + statement, + parameters=None, + execution_options=util.EMPTY_DICT, + ): + r"""Executes a SQL statement construct and returns a scalar objects. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalars` method after invoking the + :meth:`_future.Connection.execute` method. Parameters are equivalent. + + :return: a :class:`_engine.ScalarResult` object. + + .. 
+        .. versionadded:: 1.4.24
+
+        """
+        result = await self.execute(statement, parameters, execution_options)
+        return result.scalars()
+
+    async def stream_scalars(
+        self,
+        statement,
+        parameters=None,
+        execution_options=util.EMPTY_DICT,
+    ):
+        r"""Execute a SQL statement and return a streaming scalar result
+        object.
+
+        This method is shorthand for invoking the
+        :meth:`_asyncio.AsyncResult.scalars` method after invoking the
+        :meth:`_future.Connection.stream` method.  Parameters are equivalent.
+
+        :return: an :class:`_asyncio.AsyncScalarResult` object.
+
+        .. versionadded:: 1.4.24
+
+        """
+        result = await self.stream(statement, parameters, execution_options)
+        return result.scalars()
+
+    async def run_sync(self, fn, *arg, **kw):
+        """Invoke the given sync callable passing self as the first argument.
+
+        This method maintains the asyncio event loop all the way through
+        to the database connection by running the given callable in a
+        specially instrumented greenlet.
+
+        E.g.::
+
+            async with async_engine.begin() as conn:
+                await conn.run_sync(metadata.create_all)
+
+        .. note::
+
+            The provided callable is invoked inline within the asyncio event
+            loop, and will block on traditional IO calls.  IO within this
+            callable should only call into SQLAlchemy's asyncio database
+            APIs which will be properly adapted to the greenlet context.
+
+        .. seealso::
+
+            :ref:`session_run_sync`
+        """
+
+        conn = self._sync_connection()
+
+        return await greenlet_spawn(fn, conn, *arg, **kw)
+
+    def __await__(self):
+        return self.start().__await__()
+
+    async def __aexit__(self, type_, value, traceback):
+        await asyncio.shield(self.close())
+
+
+@util.create_proxy_methods(
+    Engine,
+    ":class:`_future.Engine`",
+    ":class:`_asyncio.AsyncEngine`",
+    classmethods=[],
+    methods=[
+        "clear_compiled_cache",
+        "update_execution_options",
+        "get_execution_options",
+    ],
+    attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
+)
+class AsyncEngine(ProxyComparable, AsyncConnectable):
+    """An asyncio proxy for a :class:`_engine.Engine`.
+
+    :class:`_asyncio.AsyncEngine` is acquired using the
+    :func:`_asyncio.create_async_engine` function::
+
+        from sqlalchemy.ext.asyncio import create_async_engine
+        engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname")
+
+    .. versionadded:: 1.4
+
+    """  # noqa
+
+    # AsyncEngine is a thin proxy; no state should be added here
+    # that is not retrievable from the "sync" engine / connection, e.g.
+    # current transaction, info, etc.  It should be possible to
+    # create a new AsyncEngine that matches this one given only the
+    # "sync" elements.
+    __slots__ = ("sync_engine", "_proxied")
+
+    _connection_cls = AsyncConnection
+
+    _option_cls: type
+
+    class _trans_ctx(StartableContext):
+        def __init__(self, conn):
+            self.conn = conn
+
+        async def start(self, is_ctxmanager=False):
+            await self.conn.start(is_ctxmanager=is_ctxmanager)
+            self.transaction = self.conn.begin()
+            await self.transaction.__aenter__()
+
+            return self.conn
+
+        async def __aexit__(self, type_, value, traceback):
+            async def go():
+                await self.transaction.__aexit__(type_, value, traceback)
+                await self.conn.close()
+
+            await asyncio.shield(go())
+
+    def __init__(self, sync_engine):
+        if not sync_engine.dialect.is_async:
+            raise exc.InvalidRequestError(
+                "The asyncio extension requires an async driver to be used. "
+                f"The loaded {sync_engine.dialect.driver!r} is not async."
+ ) + self.sync_engine = self._proxied = self._assign_proxied(sync_engine) + + sync_engine: Engine + """Reference to the sync-style :class:`_engine.Engine` this + :class:`_asyncio.AsyncEngine` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + """ + + @classmethod + def _regenerate_proxy_for_target(cls, target): + return AsyncEngine(target) + + def begin(self): + """Return a context manager which when entered will deliver an + :class:`_asyncio.AsyncConnection` with an + :class:`_asyncio.AsyncTransaction` established. + + E.g.:: + + async with async_engine.begin() as conn: + await conn.execute( + text("insert into table (x, y, z) values (1, 2, 3)") + ) + await conn.execute(text("my_special_procedure(5)")) + + + """ + conn = self.connect() + return self._trans_ctx(conn) + + def connect(self): + """Return an :class:`_asyncio.AsyncConnection` object. + + The :class:`_asyncio.AsyncConnection` will procure a database + connection from the underlying connection pool when it is entered + as an async context manager:: + + async with async_engine.connect() as conn: + result = await conn.execute(select(user_table)) + + The :class:`_asyncio.AsyncConnection` may also be started outside of a + context manager by invoking its :meth:`_asyncio.AsyncConnection.start` + method. + + """ + + return self._connection_cls(self) + + async def raw_connection(self): + """Return a "raw" DBAPI connection from the connection pool. + + .. seealso:: + + :ref:`dbapi_connections` + + """ + return await greenlet_spawn(self.sync_engine.raw_connection) + + def execution_options(self, **opt): + """Return a new :class:`_asyncio.AsyncEngine` that will provide + :class:`_asyncio.AsyncConnection` objects with the given execution + options. + + Proxied from :meth:`_future.Engine.execution_options`. See that + method for details. + + """ + + return AsyncEngine(self.sync_engine.execution_options(**opt)) + + async def dispose(self): + """Dispose of the connection pool used by this + :class:`_asyncio.AsyncEngine`. + + This will close all connection pool connections that are + **currently checked in**. See the documentation for the underlying + :meth:`_future.Engine.dispose` method for further notes. + + .. 
seealso::
+
+            :meth:`_future.Engine.dispose`
+
+        """
+
+        await greenlet_spawn(self.sync_engine.dispose)
+
+
+class AsyncTransaction(ProxyComparable, StartableContext):
+    """An asyncio proxy for a :class:`_engine.Transaction`."""
+
+    __slots__ = ("connection", "sync_transaction", "nested")
+
+    def __init__(self, connection, nested=False):
+        self.connection = connection  # AsyncConnection
+        self.sync_transaction = None  # sqlalchemy.engine.Transaction
+        self.nested = nested
+
+    @classmethod
+    def _regenerate_proxy_for_target(cls, target):
+        sync_connection = target.connection
+        sync_transaction = target
+        nested = isinstance(target, NestedTransaction)
+
+        async_connection = AsyncConnection._retrieve_proxy_for_target(
+            sync_connection
+        )
+        assert async_connection is not None
+
+        obj = cls.__new__(cls)
+        obj.connection = async_connection
+        obj.sync_transaction = obj._assign_proxied(sync_transaction)
+        obj.nested = nested
+        return obj
+
+    def _sync_transaction(self):
+        if not self.sync_transaction:
+            self._raise_for_not_started()
+        return self.sync_transaction
+
+    @property
+    def _proxied(self):
+        return self.sync_transaction
+
+    @property
+    def is_valid(self):
+        return self._sync_transaction().is_valid
+
+    @property
+    def is_active(self):
+        return self._sync_transaction().is_active
+
+    async def close(self):
+        """Close this :class:`.Transaction`.
+
+        If this transaction is the base transaction in a begin/commit
+        nesting, the transaction will be rolled back.  Otherwise, the
+        method returns.
+
+        This is used to cancel a Transaction without affecting the scope of
+        an enclosing transaction.
+
+        """
+        await greenlet_spawn(self._sync_transaction().close)
+
+    async def rollback(self):
+        """Roll back this :class:`.Transaction`."""
+        await greenlet_spawn(self._sync_transaction().rollback)
+
+    async def commit(self):
+        """Commit this :class:`.Transaction`."""
+
+        await greenlet_spawn(self._sync_transaction().commit)
+
+    async def start(self, is_ctxmanager=False):
+        """Start this :class:`_asyncio.AsyncTransaction` object's context
+        outside of using a Python ``async with:`` block.
+
+        """
+
+        self.sync_transaction = self._assign_proxied(
+            await greenlet_spawn(
+                self.connection._sync_connection().begin_nested
+                if self.nested
+                else self.connection._sync_connection().begin
+            )
+        )
+        if is_ctxmanager:
+            self.sync_transaction.__enter__()
+        return self
+
+    async def __aexit__(self, type_, value, traceback):
+        await greenlet_spawn(
+            self._sync_transaction().__exit__, type_, value, traceback
+        )
+
+
+def _get_sync_engine_or_connection(async_engine):
+    if isinstance(async_engine, AsyncConnection):
+        return async_engine.sync_connection
+
+    try:
+        return async_engine.sync_engine
+    except AttributeError as e:
+        raise exc.ArgumentError(
+            "AsyncEngine expected, got %r" % async_engine
+        ) from e
+
+
+@inspection._inspects(AsyncConnection)
+def _no_insp_for_async_conn_yet(subject):
+    raise exc.NoInspectionAvailable(
+        "Inspection on an AsyncConnection is currently not supported. "
+        "Please use ``run_sync`` to pass a callable where it's possible "
+        "to call ``inspect`` on the passed connection.",
+        code="xd3s",
+    )
+
+
+@inspection._inspects(AsyncEngine)
+def _no_insp_for_async_engine_yet(subject):
+    raise exc.NoInspectionAvailable(
+        "Inspection on an AsyncEngine is currently not supported. 
" + "Please obtain a connection then use ``conn.run_sync`` to pass a " + "callable where it's possible to call ``inspect`` on the " + "passed connection.", + code="xd3s", + ) diff --git a/lib/sqlalchemy/ext/asyncio/events.py b/lib/sqlalchemy/ext/asyncio/events.py new file mode 100644 index 0000000..c5d5e01 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/events.py @@ -0,0 +1,44 @@ +# ext/asyncio/events.py +# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .engine import AsyncConnectable +from .session import AsyncSession +from ...engine import events as engine_event +from ...orm import events as orm_event + + +class AsyncConnectionEvents(engine_event.ConnectionEvents): + _target_class_doc = "SomeEngine" + _dispatch_target = AsyncConnectable + + @classmethod + def _no_async_engine_events(cls): + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes." + ) + + @classmethod + def _listen(cls, event_key, retval=False): + cls._no_async_engine_events() + + +class AsyncSessionEvents(orm_event.SessionEvents): + _target_class_doc = "SomeSession" + _dispatch_target = AsyncSession + + @classmethod + def _no_async_engine_events(cls): + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncSession.sync_session." + ) + + @classmethod + def _listen(cls, event_key, retval=False): + cls._no_async_engine_events() diff --git a/lib/sqlalchemy/ext/asyncio/exc.py b/lib/sqlalchemy/ext/asyncio/exc.py new file mode 100644 index 0000000..cf0d9a8 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/exc.py @@ -0,0 +1,21 @@ +# ext/asyncio/exc.py +# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from ... import exc + + +class AsyncMethodRequired(exc.InvalidRequestError): + """an API can't be used because its result would not be + compatible with async""" + + +class AsyncContextNotStarted(exc.InvalidRequestError): + """a startable context manager has not been started.""" + + +class AsyncContextAlreadyStarted(exc.InvalidRequestError): + """a startable context manager is already started.""" diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py new file mode 100644 index 0000000..a77b6a8 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -0,0 +1,671 @@ +# ext/asyncio/result.py +# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import operator + +from . 
import exc as async_exc
+from ...engine.result import _NO_ROW
+from ...engine.result import FilterResult
+from ...engine.result import FrozenResult
+from ...engine.result import MergedResult
+from ...sql.base import _generative
+from ...util.concurrency import greenlet_spawn
+
+
+class AsyncCommon(FilterResult):
+    async def close(self):
+        """Close this result."""
+
+        await greenlet_spawn(self._real_result.close)
+
+
+class AsyncResult(AsyncCommon):
+    """An asyncio wrapper around a :class:`_result.Result` object.
+
+    The :class:`_asyncio.AsyncResult` only applies to statement executions that
+    use a server-side cursor.  It is returned only from the
+    :meth:`_asyncio.AsyncConnection.stream` and
+    :meth:`_asyncio.AsyncSession.stream` methods.
+
+    .. note:: As is the case with :class:`_engine.Result`, this object is
+       used for ORM results returned by :meth:`_asyncio.AsyncSession.stream`,
+       which can yield instances of ORM mapped objects either individually or
+       within tuple-like rows.  Note that these result objects do not
+       deduplicate instances or rows automatically as is the case with the
+       legacy :class:`_orm.Query` object.  For in-Python de-duplication of
+       instances or rows, use the :meth:`_asyncio.AsyncResult.unique` modifier
+       method.
+
+    .. versionadded:: 1.4
+
+    """
+
+    def __init__(self, real_result):
+        self._real_result = real_result
+
+        self._metadata = real_result._metadata
+        self._unique_filter_state = real_result._unique_filter_state
+
+        # BaseCursorResult pre-generates the "_row_getter".  Use that
+        # if available rather than building a second one
+        if "_row_getter" in real_result.__dict__:
+            self._set_memoized_attribute(
+                "_row_getter", real_result.__dict__["_row_getter"]
+            )
+
+    def keys(self):
+        """Return the :meth:`_engine.Result.keys` collection from the
+        underlying :class:`_engine.Result`.
+
+        """
+        return self._metadata.keys
+
+    @_generative
+    def unique(self, strategy=None):
+        """Apply unique filtering to the objects returned by this
+        :class:`_asyncio.AsyncResult`.
+
+        Refer to :meth:`_engine.Result.unique` in the synchronous
+        SQLAlchemy API for a complete behavioral description.
+
+        """
+        self._unique_filter_state = (set(), strategy)
+
+    def columns(self, *col_expressions):
+        r"""Establish the columns that should be returned in each row.
+
+        Refer to :meth:`_engine.Result.columns` in the synchronous
+        SQLAlchemy API for a complete behavioral description.
+
+        """
+        return self._column_slices(col_expressions)
+
+    async def partitions(self, size=None):
+        """Iterate through sub-lists of rows of the size given.
+
+        An async iterator is returned::
+
+            async def scroll_results(connection):
+                result = await connection.stream(select(users_table))
+
+                async for partition in result.partitions(100):
+                    print("list of rows: %s" % partition)
+
+        .. seealso::
+
+            :meth:`_engine.Result.partitions`
+
+        """
+
+        getter = self._manyrow_getter
+
+        while True:
+            partition = await greenlet_spawn(getter, self, size)
+            if partition:
+                yield partition
+            else:
+                break
+
+    async def fetchone(self):
+        """Fetch one row.
+
+        When all rows are exhausted, returns None.
+
+        This method is provided for backwards compatibility with
+        SQLAlchemy 1.x.x.
+
+        To fetch the first row of a result only, use the
+        :meth:`_engine.Result.first` method.  To iterate through all
+        rows, iterate the :class:`_engine.Result` object directly.
+
+        :return: a :class:`.Row` object if no filters are applied, or None
+         if no rows remain.
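+
+        E.g., a usage sketch, continuing the ``users_table`` example shown
+        for :meth:`_asyncio.AsyncResult.partitions` above::
+
+            result = await connection.stream(select(users_table))
+            row = await result.fetchone()
+            while row is not None:
+                print(row)
+                row = await result.fetchone()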
+ + """ + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + return None + else: + return row + + async def fetchmany(self, size=None): + """Fetch many rows. + + When all rows are exhausted, returns an empty list. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch rows in groups, use the + :meth:`._asyncio.AsyncResult.partitions` method. + + :return: a list of :class:`.Row` objects. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.partitions` + + """ + + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self): + """Return all rows in a list. + + Closes the result set after invocation. Subsequent invocations + will return an empty list. + + :return: a list of :class:`.Row` objects. + + """ + + return await greenlet_spawn(self._allrows) + + def __aiter__(self): + return self + + async def __anext__(self): + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self): + """Fetch the first row or None if no row is present. + + Closes the result set and discards remaining rows. + + .. note:: This method returns one **row**, e.g. tuple, by default. To + return exactly one single scalar value, that is, the first column of + the first row, use the :meth:`_asyncio.AsyncResult.scalar` method, + or combine :meth:`_asyncio.AsyncResult.scalars` and + :meth:`_asyncio.AsyncResult.first`. + + :return: a :class:`.Row` object, or None + if no rows remain. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.scalar` + + :meth:`_asyncio.AsyncResult.one` + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self): + """Return at most one result or raise an exception. + + Returns ``None`` if the result has no rows. + Raises :class:`.MultipleResultsFound` + if multiple rows are returned. + + .. versionadded:: 1.4 + + :return: The first :class:`.Row` or None if no row is available. + + :raises: :class:`.MultipleResultsFound` + + .. seealso:: + + :meth:`_asyncio.AsyncResult.first` + + :meth:`_asyncio.AsyncResult.one` + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def scalar_one(self): + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and + then :meth:`_asyncio.AsyncResult.one`. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.one` + + :meth:`_asyncio.AsyncResult.scalars` + + """ + return await greenlet_spawn(self._only_one_row, True, True, True) + + async def scalar_one_or_none(self): + """Return exactly one or no scalar result. + + This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and + then :meth:`_asyncio.AsyncResult.one_or_none`. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.one_or_none` + + :meth:`_asyncio.AsyncResult.scalars` + + """ + return await greenlet_spawn(self._only_one_row, True, False, True) + + async def one(self): + """Return exactly one row or raise an exception. + + Raises :class:`.NoResultFound` if the result returns no + rows, or :class:`.MultipleResultsFound` if multiple rows + would be returned. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_asyncio.AsyncResult.scalar_one` method, or combine + :meth:`_asyncio.AsyncResult.scalars` and + :meth:`_asyncio.AsyncResult.one`. + + .. 
+        .. versionadded:: 1.4
+
+        :return: The first :class:`.Row`.
+
+        :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound`
+
+        .. seealso::
+
+            :meth:`_asyncio.AsyncResult.first`
+
+            :meth:`_asyncio.AsyncResult.one_or_none`
+
+            :meth:`_asyncio.AsyncResult.scalar_one`
+
+        """
+        return await greenlet_spawn(self._only_one_row, True, True, False)
+
+    async def scalar(self):
+        """Fetch the first column of the first row, and close the result set.
+
+        Returns None if there are no rows to fetch.
+
+        No validation is performed to test if additional rows remain.
+
+        After calling this method, the object is fully closed,
+        e.g. the :meth:`_engine.CursorResult.close`
+        method will have been called.
+
+        :return: a Python scalar value, or None if no rows remain.
+
+        """
+        return await greenlet_spawn(self._only_one_row, False, False, True)
+
+    async def freeze(self):
+        """Return a callable object that will produce copies of this
+        :class:`_asyncio.AsyncResult` when invoked.
+
+        The callable object returned is an instance of
+        :class:`_engine.FrozenResult`.
+
+        This is used for result set caching.  The method must be called
+        on the result when it has been unconsumed, and calling the method
+        will consume the result fully.  When the :class:`_engine.FrozenResult`
+        is retrieved from a cache, it can be called any number of times where
+        it will produce a new :class:`_engine.Result` object each time
+        against its stored set of rows.
+
+        .. seealso::
+
+            :ref:`do_orm_execute_re_executing` - example usage within the
+            ORM to implement a result-set cache.
+
+        """
+
+        return await greenlet_spawn(FrozenResult, self)
+
+    def merge(self, *others):
+        """Merge this :class:`_asyncio.AsyncResult` with other compatible
+        result objects.
+
+        The object returned is an instance of :class:`_engine.MergedResult`,
+        which will be composed of iterators from the given result
+        objects.
+
+        The new result will use the metadata from this result object.
+        The subsequent result objects must be against an identical
+        set of result / cursor metadata, otherwise the behavior is
+        undefined.
+
+        """
+        return MergedResult(self._metadata, (self,) + others)
+
+    def scalars(self, index=0):
+        """Return an :class:`_asyncio.AsyncScalarResult` filtering object which
+        will return single elements rather than :class:`_row.Row` objects.
+
+        Refer to :meth:`_result.Result.scalars` in the synchronous
+        SQLAlchemy API for a complete behavioral description.
+
+        :param index: integer or row key indicating the column to be fetched
+         from each row, defaults to ``0`` indicating the first column.
+
+        :return: a new :class:`_asyncio.AsyncScalarResult` filtering object
+         referring to this :class:`_asyncio.AsyncResult` object.
+
+        """
+        return AsyncScalarResult(self._real_result, index)
+
+    def mappings(self):
+        """Apply a mappings filter to returned rows, returning an instance of
+        :class:`_asyncio.AsyncMappingResult`.
+
+        When this filter is applied, fetching rows will return
+        :class:`.RowMapping` objects instead of :class:`.Row` objects.
+
+        Refer to :meth:`_result.Result.mappings` in the synchronous
+        SQLAlchemy API for a complete behavioral description.
+
+        :return: a new :class:`_asyncio.AsyncMappingResult` filtering object
+         referring to the underlying :class:`_result.Result` object.
+
+        """
+
+        return AsyncMappingResult(self._real_result)
+
+
+class AsyncScalarResult(AsyncCommon):
+    """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values
+    rather than :class:`_row.Row` values.
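+
+    E.g., a usage sketch (``conn`` and ``users_table`` as in the earlier
+    streaming examples)::
+
+        result = await conn.stream(select(users_table.c.id))
+        async for id_value in result.scalars():
+            print(id_value)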
+ + The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the + :meth:`_asyncio.AsyncResult.scalars` method. + + Refer to the :class:`_result.ScalarResult` object in the synchronous + SQLAlchemy API for a complete behavioral description. + + .. versionadded:: 1.4 + + """ + + _generate_rows = False + + def __init__(self, real_result, index): + self._real_result = real_result + + if real_result._source_supports_scalars: + self._metadata = real_result._metadata + self._post_creational_filter = None + else: + self._metadata = real_result._metadata._reduce([index]) + self._post_creational_filter = operator.itemgetter(0) + + self._unique_filter_state = real_result._unique_filter_state + + def unique(self, strategy=None): + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncScalarResult`. + + See :meth:`_asyncio.AsyncResult.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + async def partitions(self, size=None): + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self): + """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method.""" + + return await greenlet_spawn(self._allrows) + + async def fetchmany(self, size=None): + """Fetch many objects. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self): + """Return all scalar values in a list. + + Equivalent to :meth:`_asyncio.AsyncResult.all` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._allrows) + + def __aiter__(self): + return self + + async def __anext__(self): + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self): + """Fetch the first object or None if no object is present. + + Equivalent to :meth:`_asyncio.AsyncResult.first` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self): + """Return at most one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def one(self): + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one` except that + scalar values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + +class AsyncMappingResult(AsyncCommon): + """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary + values rather than :class:`_engine.Row` values. + + The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the + :meth:`_asyncio.AsyncResult.mappings` method. 
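+
+    E.g., a usage sketch (``conn`` and ``users_table`` as in the earlier
+    examples; the ``"name"`` key is illustrative)::
+
+        result = await conn.stream(select(users_table))
+        async for row_mapping in result.mappings():
+            print(row_mapping["name"])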
+
+    Refer to the :class:`_result.MappingResult` object in the synchronous
+    SQLAlchemy API for a complete behavioral description.
+
+    .. versionadded:: 1.4
+
+    """
+
+    _generate_rows = True
+
+    _post_creational_filter = operator.attrgetter("_mapping")
+
+    def __init__(self, result):
+        self._real_result = result
+        self._unique_filter_state = result._unique_filter_state
+        self._metadata = result._metadata
+        if result._source_supports_scalars:
+            self._metadata = self._metadata._reduce([0])
+
+    def keys(self):
+        """Return an iterable view which yields the string keys that would
+        be represented by each :class:`.Row`.
+
+        The view also can be tested for key containment using the Python
+        ``in`` operator, which will test both for the string keys represented
+        in the view, as well as for alternate keys such as column objects.
+
+        .. versionchanged:: 1.4 a key view object is returned rather than a
+           plain list.
+
+        """
+        return self._metadata.keys
+
+    def unique(self, strategy=None):
+        """Apply unique filtering to the objects returned by this
+        :class:`_asyncio.AsyncMappingResult`.
+
+        See :meth:`_asyncio.AsyncResult.unique` for usage details.
+
+        """
+        self._unique_filter_state = (set(), strategy)
+        return self
+
+    def columns(self, *col_expressions):
+        r"""Establish the columns that should be returned in each row."""
+        return self._column_slices(col_expressions)
+
+    async def partitions(self, size=None):
+        """Iterate through sub-lists of elements of the size given.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+
+        getter = self._manyrow_getter
+
+        while True:
+            partition = await greenlet_spawn(getter, self, size)
+            if partition:
+                yield partition
+            else:
+                break
+
+    async def fetchall(self):
+        """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method."""
+
+        return await greenlet_spawn(self._allrows)
+
+    async def fetchone(self):
+        """Fetch one object.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+
+        row = await greenlet_spawn(self._onerow_getter, self)
+        if row is _NO_ROW:
+            return None
+        else:
+            return row
+
+    async def fetchmany(self, size=None):
+        """Fetch many objects.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+
+        return await greenlet_spawn(self._manyrow_getter, self, size)
+
+    async def all(self):
+        """Return all mapping values in a list.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.all` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+
+        return await greenlet_spawn(self._allrows)
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        row = await greenlet_spawn(self._onerow_getter, self)
+        if row is _NO_ROW:
+            raise StopAsyncIteration()
+        else:
+            return row
+
+    async def first(self):
+        """Fetch the first object or None if no object is present.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.first` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+
+        """
+        return await greenlet_spawn(self._only_one_row, False, False, False)
+
+    async def one_or_none(self):
+        """Return at most one object or raise an exception.
+
+        Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that
+        mapping values, rather than :class:`_result.Row` objects,
+        are returned.
+ + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def one(self): + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one` except that + mapping values, rather than :class:`_result.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + +async def _ensure_sync_result(result, calling_method): + if not result._is_cursor: + cursor_result = getattr(result, "raw", None) + else: + cursor_result = result + if cursor_result and cursor_result.context._is_server_side: + await greenlet_spawn(cursor_result.close) + raise async_exc.AsyncMethodRequired( + "Can't use the %s.%s() method with a " + "server-side cursor. " + "Use the %s.stream() method for an async " + "streaming result set." + % ( + calling_method.__self__.__class__.__name__, + calling_method.__name__, + calling_method.__self__.__class__.__name__, + ) + ) + return result diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py new file mode 100644 index 0000000..8eca8c5 --- /dev/null +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -0,0 +1,107 @@ +# ext/asyncio/scoping.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .session import AsyncSession +from ...orm.scoping import ScopedSessionMixin +from ...util import create_proxy_methods +from ...util import ScopedRegistry + + +@create_proxy_methods( + AsyncSession, + ":class:`_asyncio.AsyncSession`", + ":class:`_asyncio.scoping.async_scoped_session`", + classmethods=["close_all", "object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "begin", + "begin_nested", + "close", + "commit", + "connection", + "delete", + "execute", + "expire", + "expire_all", + "expunge", + "expunge_all", + "flush", + "get", + "get_bind", + "is_modified", + "invalidate", + "merge", + "refresh", + "rollback", + "scalar", + "scalars", + "stream", + "stream_scalars", + ], + attributes=[ + "bind", + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], +) +class async_scoped_session(ScopedSessionMixin): + """Provides scoped management of :class:`.AsyncSession` objects. + + See the section :ref:`asyncio_scoped_session` for usage details. + + .. versionadded:: 1.4.19 + + + """ + + _support_async = True + + def __init__(self, session_factory, scopefunc): + """Construct a new :class:`_asyncio.async_scoped_session`. + + :param session_factory: a factory to create new :class:`_asyncio.AsyncSession` + instances. This is usually, but not necessarily, an instance + of :class:`_orm.sessionmaker` which itself was passed the + :class:`_asyncio.AsyncSession` to its :paramref:`_orm.sessionmaker.class_` + parameter:: + + async_session_factory = sessionmaker(some_async_engine, class_= AsyncSession) + AsyncSession = async_scoped_session(async_session_factory, scopefunc=current_task) + + :param scopefunc: function which defines + the current scope. A function such as ``asyncio.current_task`` + may be useful here. + + """ # noqa: E501 + + self.session_factory = session_factory + self.registry = ScopedRegistry(session_factory, scopefunc) + + @property + def _proxied(self): + return self.registry() + + async def remove(self): + """Dispose of the current :class:`.AsyncSession`, if present. 
+
+        Unlike the :meth:`_orm.scoped_session.remove` method of the
+        synchronous API, this method is a coroutine and awaits the
+        :meth:`.AsyncSession.close` method of the current session.
+
+        """
+
+        if self.registry.has():
+            await self.registry().close()
+        self.registry.clear()
diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py
new file mode 100644
index 0000000..378cbcb
--- /dev/null
+++ b/lib/sqlalchemy/ext/asyncio/session.py
@@ -0,0 +1,759 @@
+# ext/asyncio/session.py
+# Copyright (C) 2020-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import asyncio
+
+from . import engine
+from . import result as _result
+from .base import ReversibleProxy
+from .base import StartableContext
+from .result import _ensure_sync_result
+from ... import util
+from ...orm import object_session
+from ...orm import Session
+from ...orm import state as _instance_state
+from ...util.concurrency import greenlet_spawn
+
+_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True})
+_STREAM_OPTIONS = util.immutabledict({"stream_results": True})
+
+
+@util.create_proxy_methods(
+    Session,
+    ":class:`_orm.Session`",
+    ":class:`_asyncio.AsyncSession`",
+    classmethods=["object_session", "identity_key"],
+    methods=[
+        "__contains__",
+        "__iter__",
+        "add",
+        "add_all",
+        "expire",
+        "expire_all",
+        "expunge",
+        "expunge_all",
+        "is_modified",
+        "in_transaction",
+        "in_nested_transaction",
+    ],
+    attributes=[
+        "dirty",
+        "deleted",
+        "new",
+        "identity_map",
+        "is_active",
+        "autoflush",
+        "no_autoflush",
+        "info",
+    ],
+)
+class AsyncSession(ReversibleProxy):
+    """Asyncio version of :class:`_orm.Session`.
+
+    The :class:`_asyncio.AsyncSession` is a proxy for a traditional
+    :class:`_orm.Session` instance.
+
+    .. versionadded:: 1.4
+
+    To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session`
+    implementations, see the
+    :paramref:`_asyncio.AsyncSession.sync_session_class` parameter.
+
+    """
+
+    _is_asyncio = True
+
+    dispatch = None
+
+    def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
+        r"""Construct a new :class:`_asyncio.AsyncSession`.
+
+        All parameters other than ``sync_session_class`` are passed to the
+        ``sync_session_class`` callable directly to instantiate a new
+        :class:`_orm.Session`.  Refer to :meth:`_orm.Session.__init__` for
+        parameter documentation.
+
+        :param sync_session_class:
+          A :class:`_orm.Session` subclass or other callable which will be used
+          to construct the :class:`_orm.Session` which will be proxied.  This
+          parameter may be used to provide custom :class:`_orm.Session`
+          subclasses.  Defaults to the
+          :attr:`_asyncio.AsyncSession.sync_session_class` class-level
+          attribute.
+
+          .. versionadded:: 1.4.24
+
+        """
+        kw["future"] = True
+        if bind:
+            self.bind = bind
+            bind = engine._get_sync_engine_or_connection(bind)
+
+        if binds:
+            self.binds = binds
+            binds = {
+                key: engine._get_sync_engine_or_connection(b)
+                for key, b in binds.items()
+            }
+
+        if sync_session_class:
+            self.sync_session_class = sync_session_class
+
+        self.sync_session = self._proxied = self._assign_proxied(
+            self.sync_session_class(bind=bind, binds=binds, **kw)
+        )
+
+    sync_session_class = Session
+    """The class or callable that provides the
+    underlying :class:`_orm.Session` instance for a particular
+    :class:`_asyncio.AsyncSession`.
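+
+    E.g., a brief sketch of supplying a custom class here (``MySession``
+    and ``some_async_engine`` are hypothetical)::
+
+        class MySession(Session):
+            pass
+
+        session = AsyncSession(
+            some_async_engine, sync_session_class=MySession
+        )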
+
+    At the class level, this attribute is the default value for the
+    :paramref:`_asyncio.AsyncSession.sync_session_class` parameter.  Custom
+    subclasses of :class:`_asyncio.AsyncSession` can override this.
+
+    At the instance level, this attribute indicates the current class or
+    callable that was used to provide the :class:`_orm.Session` instance for
+    this :class:`_asyncio.AsyncSession` instance.
+
+    .. versionadded:: 1.4.24
+
+    """
+
+    sync_session: Session
+    """Reference to the underlying :class:`_orm.Session` this
+    :class:`_asyncio.AsyncSession` proxies requests towards.
+
+    This instance can be used as an event target.
+
+    .. seealso::
+
+        :ref:`asyncio_events`
+
+    """
+
+    async def refresh(
+        self, instance, attribute_names=None, with_for_update=None
+    ):
+        """Expire and refresh the attributes on the given instance.
+
+        A query will be issued to the database and all attributes will be
+        refreshed with their current database value.
+
+        This is the async version of the :meth:`_orm.Session.refresh` method.
+        See that method for a complete description of all options.
+
+        .. seealso::
+
+            :meth:`_orm.Session.refresh` - main documentation for refresh
+
+        """
+
+        return await greenlet_spawn(
+            self.sync_session.refresh,
+            instance,
+            attribute_names=attribute_names,
+            with_for_update=with_for_update,
+        )
+
+    async def run_sync(self, fn, *arg, **kw):
+        """Invoke the given sync callable passing sync self as the first
+        argument.
+
+        This method maintains the asyncio event loop all the way through
+        to the database connection by running the given callable in a
+        specially instrumented greenlet.
+
+        E.g.::
+
+            async with AsyncSession(async_engine) as session:
+                await session.run_sync(some_business_method)
+
+        .. note::
+
+            The provided callable is invoked inline within the asyncio event
+            loop, and will block on traditional IO calls.  IO within this
+            callable should only call into SQLAlchemy's asyncio database
+            APIs which will be properly adapted to the greenlet context.
+
+        .. seealso::
+
+            :ref:`session_run_sync`
+        """
+
+        return await greenlet_spawn(fn, self.sync_session, *arg, **kw)
+
+    async def execute(
+        self,
+        statement,
+        params=None,
+        execution_options=util.EMPTY_DICT,
+        bind_arguments=None,
+        **kw
+    ):
+        """Execute a statement and return a buffered
+        :class:`_engine.Result` object.
+
+        .. seealso::
+
+            :meth:`_orm.Session.execute` - main documentation for execute
+
+        """
+
+        if execution_options:
+            execution_options = util.immutabledict(execution_options).union(
+                _EXECUTE_OPTIONS
+            )
+        else:
+            execution_options = _EXECUTE_OPTIONS
+
+        result = await greenlet_spawn(
+            self.sync_session.execute,
+            statement,
+            params=params,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+            **kw
+        )
+        return await _ensure_sync_result(result, self.execute)
+
+    async def scalar(
+        self,
+        statement,
+        params=None,
+        execution_options=util.EMPTY_DICT,
+        bind_arguments=None,
+        **kw
+    ):
+        """Execute a statement and return a scalar result.
+
+        .. seealso::
+
+            :meth:`_orm.Session.scalar` - main documentation for scalar
+
+        """
+
+        result = await self.execute(
+            statement,
+            params=params,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+            **kw
+        )
+        return result.scalar()
+
+    async def scalars(
+        self,
+        statement,
+        params=None,
+        execution_options=util.EMPTY_DICT,
+        bind_arguments=None,
+        **kw
+    ):
+        """Execute a statement and return scalar results.
+
+        :return: a :class:`_result.ScalarResult` object
+
+        .. versionadded:: 1.4.24
+
+        .. 
seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version + + """ + + result = await self.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw + ) + return result.scalars() + + async def get( + self, + entity, + ident, + options=None, + populate_existing=False, + with_for_update=None, + identity_token=None, + ): + """Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. seealso:: + + :meth:`_orm.Session.get` - main documentation for get + + + """ + return await greenlet_spawn( + self.sync_session.get, + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + ) + + async def stream( + self, + statement, + params=None, + execution_options=util.EMPTY_DICT, + bind_arguments=None, + **kw + ): + """Execute a statement and return a streaming + :class:`_asyncio.AsyncResult` object. + + """ + + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _STREAM_OPTIONS + ) + else: + execution_options = _STREAM_OPTIONS + + result = await greenlet_spawn( + self.sync_session.execute, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw + ) + return _result.AsyncResult(result) + + async def stream_scalars( + self, + statement, + params=None, + execution_options=util.EMPTY_DICT, + bind_arguments=None, + **kw + ): + """Execute a statement and return a stream of scalar results. + + :return: an :class:`_asyncio.AsyncScalarResult` object + + .. versionadded:: 1.4.24 + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.scalars` - non streaming version + + """ + + result = await self.stream( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw + ) + return result.scalars() + + async def delete(self, instance): + """Mark an instance as deleted. + + The database delete operation occurs upon ``flush()``. + + As this operation may need to cascade along unloaded relationships, + it is awaitable to allow for those queries to take place. + + .. seealso:: + + :meth:`_orm.Session.delete` - main documentation for delete + + """ + return await greenlet_spawn(self.sync_session.delete, instance) + + async def merge(self, instance, load=True, options=None): + """Copy the state of a given instance into a corresponding instance + within this :class:`_asyncio.AsyncSession`. + + .. seealso:: + + :meth:`_orm.Session.merge` - main documentation for merge + + """ + return await greenlet_spawn( + self.sync_session.merge, instance, load=load, options=options + ) + + async def flush(self, objects=None): + """Flush all the object changes to the database. + + .. seealso:: + + :meth:`_orm.Session.flush` - main documentation for flush + + """ + await greenlet_spawn(self.sync_session.flush, objects=objects) + + def get_transaction(self): + """Return the current root transaction in progress, if any. + + :return: an :class:`_asyncio.AsyncSessionTransaction` object, or + ``None``. + + .. 
versionadded:: 1.4.18
+
+        """
+        trans = self.sync_session.get_transaction()
+        if trans is not None:
+            return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+        else:
+            return None
+
+    def get_nested_transaction(self):
+        """Return the current nested transaction in progress, if any.
+
+        :return: an :class:`_asyncio.AsyncSessionTransaction` object, or
+         ``None``.
+
+        .. versionadded:: 1.4.18
+
+        """
+
+        trans = self.sync_session.get_nested_transaction()
+        if trans is not None:
+            return AsyncSessionTransaction._retrieve_proxy_for_target(trans)
+        else:
+            return None
+
+    def get_bind(self, mapper=None, clause=None, bind=None, **kw):
+        """Return a "bind" to which the synchronous proxied :class:`_orm.Session`
+        is bound.
+
+        Unlike the :meth:`_orm.Session.get_bind` method, this method is
+        currently **not** used by this :class:`.AsyncSession` in any way
+        in order to resolve engines for requests.
+
+        .. note::
+
+            This method proxies directly to the :meth:`_orm.Session.get_bind`
+            method, however is currently **not** useful as an override target,
+            in contrast to that of the :meth:`_orm.Session.get_bind` method.
+            The example below illustrates how to implement custom
+            :meth:`_orm.Session.get_bind` schemes that work with
+            :class:`.AsyncSession` and :class:`.AsyncEngine`.
+
+        The pattern introduced at :ref:`session_custom_partitioning`
+        illustrates how to apply a custom bind-lookup scheme to a
+        :class:`_orm.Session` given a set of :class:`_engine.Engine` objects.
+        To apply a corresponding :meth:`_orm.Session.get_bind` implementation
+        for use with a :class:`.AsyncSession` and :class:`.AsyncEngine`
+        objects, continue to subclass :class:`_orm.Session` and apply it to
+        :class:`.AsyncSession` using
+        :paramref:`.AsyncSession.sync_session_class`.  The inner method must
+        continue to return :class:`_engine.Engine` instances, which can be
+        acquired from a :class:`_asyncio.AsyncEngine` using the
+        :attr:`_asyncio.AsyncEngine.sync_engine` attribute::
+
+            # using example from "Custom Vertical Partitioning"
+
+            import random
+
+            from sqlalchemy.ext.asyncio import AsyncSession
+            from sqlalchemy.ext.asyncio import create_async_engine
+            from sqlalchemy.orm import Session, sessionmaker
+            from sqlalchemy.sql.expression import Delete, Update
+
+            # construct async engines w/ async drivers
+            engines = {
+                'leader': create_async_engine("sqlite+aiosqlite:///leader.db"),
+                'other': create_async_engine("sqlite+aiosqlite:///other.db"),
+                'follower1': create_async_engine("sqlite+aiosqlite:///follower1.db"),
+                'follower2': create_async_engine("sqlite+aiosqlite:///follower2.db"),
+            }
+
+            class RoutingSession(Session):
+                def get_bind(self, mapper=None, clause=None, **kw):
+                    # within get_bind(), return sync engines
+                    if mapper and issubclass(mapper.class_, MyOtherClass):
+                        return engines['other'].sync_engine
+                    elif self._flushing or isinstance(clause, (Update, Delete)):
+                        return engines['leader'].sync_engine
+                    else:
+                        return engines[
+                            random.choice(['follower1', 'follower2'])
+                        ].sync_engine
+
+            # apply to AsyncSession using sync_session_class
+            AsyncSessionMaker = sessionmaker(
+                class_=AsyncSession,
+                sync_session_class=RoutingSession
+            )
+
+        The :meth:`_orm.Session.get_bind` method is called in a non-asyncio,
+        implicitly non-blocking context in the same manner as ORM event hooks
+        and functions that are invoked via :meth:`.AsyncSession.run_sync`, so
+        routines that wish to run SQL commands inside of
+        :meth:`_orm.Session.get_bind` can continue to do so using
+        blocking-style code, which will be translated to implicitly async
+        calls at the point of invoking IO on the database drivers.
+
+        """  # noqa: E501
+
+        return self.sync_session.get_bind(
+            mapper=mapper, clause=clause, bind=bind, **kw
+        )
+
+    async def connection(self, **kw):
+        r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to
+        this :class:`.Session` object's transactional state.
+
+        This method may also be used to establish execution options for the
+        database connection used by the current transaction.
+
+        .. versionadded:: 1.4.24  Added \**kw arguments which are passed
+           through to the underlying :meth:`_orm.Session.connection` method.
+
+        .. seealso::
+
+            :meth:`_orm.Session.connection` - main documentation for
+            "connection"
+
+        """
+
+        sync_connection = await greenlet_spawn(
+            self.sync_session.connection, **kw
+        )
+        return engine.AsyncConnection._retrieve_proxy_for_target(
+            sync_connection
+        )
+
+    def begin(self, **kw):
+        """Return an :class:`_asyncio.AsyncSessionTransaction` object.
+
+        The underlying :class:`_orm.Session` will perform the
+        "begin" action when the :class:`_asyncio.AsyncSessionTransaction`
+        object is entered::
+
+            async with async_session.begin():
+                # .. ORM transaction is begun
+
+        Note that database IO will not normally occur when the session-level
+        transaction is begun, as database transactions begin on an
+        on-demand basis.  However, the begin block is async to accommodate
+        a :meth:`_orm.SessionEvents.after_transaction_create`
+        event hook that may perform IO.
+
+        For a general description of ORM begin, see
+        :meth:`_orm.Session.begin`.
+
+        """
+
+        return AsyncSessionTransaction(self)
+
+    def begin_nested(self, **kw):
+        """Return an :class:`_asyncio.AsyncSessionTransaction` object
+        which will begin a "nested" transaction, e.g. SAVEPOINT.
+
+        Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`.
+
+        For a general description of ORM begin nested, see
+        :meth:`_orm.Session.begin_nested`.
+
+        """
+
+        return AsyncSessionTransaction(self, nested=True)
+
+    async def rollback(self):
+        """Roll back the current transaction in progress."""
+        return await greenlet_spawn(self.sync_session.rollback)
+
+    async def commit(self):
+        """Commit the current transaction in progress."""
+        return await greenlet_spawn(self.sync_session.commit)
+
+    async def close(self):
+        """Close out the transactional resources and ORM objects used by this
+        :class:`_asyncio.AsyncSession`.
+
+        This expunges all ORM objects associated with this
+        :class:`_asyncio.AsyncSession`, ends any transaction in progress and
+        :term:`releases` any :class:`_asyncio.AsyncConnection` objects which
+        this :class:`_asyncio.AsyncSession` itself has checked out from
+        associated :class:`_asyncio.AsyncEngine` objects.  The operation then
+        leaves the :class:`_asyncio.AsyncSession` in a state in which it may
+        be used again.
+
+        .. tip::
+
+            The :meth:`_asyncio.AsyncSession.close` method **does not prevent
+            the Session from being used again**.  The
+            :class:`_asyncio.AsyncSession` itself does not actually have a
+            distinct "closed" state; it merely means the
+            :class:`_asyncio.AsyncSession` will release all database
+            connections and ORM objects.
+
+        .. seealso::
+
+            :ref:`session_closing` - detail on the semantics of
+            :meth:`_asyncio.AsyncSession.close`
+
+        """
+        await greenlet_spawn(self.sync_session.close)
+
+    async def invalidate(self):
+        """Close this Session, using connection invalidation.
+
+        For a complete description, see :meth:`_orm.Session.invalidate`.
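+
+        E.g., a hedged sketch of discarding connections after a fatal
+        driver-level failure (``DBAPIError`` is from ``sqlalchemy.exc``;
+        the handling policy shown is illustrative only)::
+
+            try:
+                await session.commit()
+            except DBAPIError:
+                await session.invalidate()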
+ """ + return await greenlet_spawn(self.sync_session.invalidate) + + @classmethod + async def close_all(self): + """Close all :class:`_asyncio.AsyncSession` sessions.""" + return await greenlet_spawn(self.sync_session.close_all) + + async def __aenter__(self): + return self + + async def __aexit__(self, type_, value, traceback): + await asyncio.shield(self.close()) + + def _maker_context_manager(self): + # no @contextlib.asynccontextmanager until python3.7, gr + return _AsyncSessionContextManager(self) + + +class _AsyncSessionContextManager: + def __init__(self, async_session): + self.async_session = async_session + + async def __aenter__(self): + self.trans = self.async_session.begin() + await self.trans.__aenter__() + return self.async_session + + async def __aexit__(self, type_, value, traceback): + async def go(): + await self.trans.__aexit__(type_, value, traceback) + await self.async_session.__aexit__(type_, value, traceback) + + await asyncio.shield(go()) + + +class AsyncSessionTransaction(ReversibleProxy, StartableContext): + """A wrapper for the ORM :class:`_orm.SessionTransaction` object. + + This object is provided so that a transaction-holding object + for the :meth:`_asyncio.AsyncSession.begin` may be returned. + + The object supports both explicit calls to + :meth:`_asyncio.AsyncSessionTransaction.commit` and + :meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an + async context manager. + + + .. versionadded:: 1.4 + + """ + + __slots__ = ("session", "sync_transaction", "nested") + + def __init__(self, session, nested=False): + self.session = session + self.nested = nested + self.sync_transaction = None + + @property + def is_active(self): + return ( + self._sync_transaction() is not None + and self._sync_transaction().is_active + ) + + def _sync_transaction(self): + if not self.sync_transaction: + self._raise_for_not_started() + return self.sync_transaction + + async def rollback(self): + """Roll back this :class:`_asyncio.AsyncTransaction`.""" + await greenlet_spawn(self._sync_transaction().rollback) + + async def commit(self): + """Commit this :class:`_asyncio.AsyncTransaction`.""" + + await greenlet_spawn(self._sync_transaction().commit) + + async def start(self, is_ctxmanager=False): + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.session.sync_session.begin_nested + if self.nested + else self.session.sync_session.begin + ) + ) + if is_ctxmanager: + self.sync_transaction.__enter__() + return self + + async def __aexit__(self, type_, value, traceback): + await greenlet_spawn( + self._sync_transaction().__exit__, type_, value, traceback + ) + + +def async_object_session(instance): + """Return the :class:`_asyncio.AsyncSession` to which the given instance + belongs. + + This function makes use of the sync-API function + :class:`_orm.object_session` to retrieve the :class:`_orm.Session` which + refers to the given instance, and from there links it to the original + :class:`_asyncio.AsyncSession`. + + If the :class:`_asyncio.AsyncSession` has been garbage collected, the + return value is ``None``. + + This functionality is also available from the + :attr:`_orm.InstanceState.async_session` accessor. + + :param instance: an ORM mapped instance + :return: an :class:`_asyncio.AsyncSession` object, or ``None``. + + .. 
versionadded:: 1.4.18 + + """ + + session = object_session(instance) + if session is not None: + return async_session(session) + else: + return None + + +def async_session(session): + """Return the :class:`_asyncio.AsyncSession` which is proxying the given + :class:`_orm.Session` object, if any. + + :param session: a :class:`_orm.Session` instance. + :return: a :class:`_asyncio.AsyncSession` instance, or ``None``. + + .. versionadded:: 1.4.18 + + """ + return AsyncSession._retrieve_proxy_for_target(session, regenerate=False) + + +_instance_state._async_provider = async_session diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py new file mode 100644 index 0000000..a5d7267 --- /dev/null +++ b/lib/sqlalchemy/ext/automap.py @@ -0,0 +1,1234 @@ +# ext/automap.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Define an extension to the :mod:`sqlalchemy.ext.declarative` system +which automatically generates mapped classes and relationships from a database +schema, typically though not necessarily one which is reflected. + +It is hoped that the :class:`.AutomapBase` system provides a quick +and modernized solution to the problem that the very famous +`SQLSoup <https://sqlsoup.readthedocs.io/en/latest/>`_ +also tries to solve, that of generating a quick and rudimentary object +model from an existing database on the fly. By addressing the issue strictly +at the mapper configuration level, and integrating fully with existing +Declarative class techniques, :class:`.AutomapBase` seeks to provide +a well-integrated approach to the issue of expediently auto-generating ad-hoc +mappings. + +.. tip:: The :ref:`automap_toplevel` extension is geared towards a + "zero declaration" approach, where a complete ORM model including classes + and pre-named relationships can be generated on the fly from a database + schema. For applications that still want to use explicit class declarations + including explicit relationship definitions in conjunction with reflection + of tables, the :class:`.DeferredReflection` class, described at + :ref:`orm_declarative_reflected_deferred_reflection`, is a better choice. + + + +Basic Use +========= + +The simplest usage is to reflect an existing database into a new model. +We create a new :class:`.AutomapBase` class in a similar manner as to how +we create a declarative base class, using :func:`.automap_base`. +We then call :meth:`.AutomapBase.prepare` on the resulting base class, +asking it to reflect the schema and produce mappings:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy.orm import Session + from sqlalchemy import create_engine + + Base = automap_base() + + # engine, suppose it has two tables 'user' and 'address' set up + engine = create_engine("sqlite:///mydatabase.db") + + # reflect the tables + Base.prepare(autoload_with=engine) + + # mapped classes are now created with names by default + # matching that of the table name. 
+    User = Base.classes.user
+    Address = Base.classes.address
+
+    session = Session(engine)
+
+    # rudimentary relationships are produced
+    u1 = User(name="foo")
+    session.add(Address(email_address="foo@bar.com", user=u1))
+    session.commit()
+
+    # collection-based relationships are by default named
+    # "<classname>_collection"
+    print(u1.address_collection)
+
+Above, calling :meth:`.AutomapBase.prepare` while passing along the
+:paramref:`.AutomapBase.prepare.autoload_with` parameter indicates that the
+:meth:`_schema.MetaData.reflect` method will be called on this declarative
+base class's :class:`_schema.MetaData` collection; then, each **viable**
+:class:`_schema.Table` within the :class:`_schema.MetaData` will get a new
+mapped class generated automatically.  The :class:`_schema.ForeignKeyConstraint`
+objects which link the various tables together will be used to produce new,
+bidirectional :func:`_orm.relationship` objects between classes.  The classes
+and relationships follow a default naming scheme that we can customize.  At
+this point, our basic mapping consisting of related ``User`` and ``Address``
+classes is ready to use in the traditional way.
+
+.. note:: By **viable**, we mean that for a table to be mapped, it must
+   specify a primary key.  Additionally, if the table is detected as being
+   a pure association table between two other tables, it will not be directly
+   mapped and will instead be configured as a many-to-many table between
+   the mappings for the two referring tables.
+
+Generating Mappings from an Existing MetaData
+=============================================
+
+We can pass a pre-declared :class:`_schema.MetaData` object to
+:func:`.automap_base`.
+This object can be constructed in any way, including programmatically, from
+a serialized file, or from itself being reflected using
+:meth:`_schema.MetaData.reflect`.
+Below we illustrate a combination of reflection and
+explicit table declaration::
+
+    from sqlalchemy import create_engine, MetaData, Table, Column, Integer, ForeignKey
+    from sqlalchemy.ext.automap import automap_base
+    engine = create_engine("sqlite:///mydatabase.db")
+
+    # produce our own MetaData object
+    metadata = MetaData()
+
+    # we can reflect it ourselves from a database, using options
+    # such as 'only' to limit what tables we look at...
+    metadata.reflect(engine, only=['user', 'address'])
+
+    # ... or just define our own Table objects with it (or combine both)
+    Table('user_order', metadata,
+        Column('id', Integer, primary_key=True),
+        Column('user_id', ForeignKey('user.id'))
+    )
+
+    # we can then produce a set of mappings from this MetaData.
+    Base = automap_base(metadata=metadata)
+
+    # calling prepare() just sets up mapped classes and relationships.
+    Base.prepare()
+
+    # mapped classes are ready
+    User, Address, Order = Base.classes.user, Base.classes.address,\
+        Base.classes.user_order
+
+Specifying Classes Explicitly
+=============================
+
+.. tip::  If explicit classes are expected to be prominent in an application,
+   consider using :class:`.DeferredReflection` instead.
+
+The :mod:`.sqlalchemy.ext.automap` extension allows classes to be defined
+explicitly, in a way similar to that of the :class:`.DeferredReflection` class.
+Classes that extend from :class:`.AutomapBase` act like regular declarative
+classes, but are not immediately mapped after their construction, and are
+instead mapped when we call :meth:`.AutomapBase.prepare`.  
The
+:meth:`.AutomapBase.prepare` method will make use of the classes we've
+established based on the table name we use.  If our schema contains tables
+``user`` and ``address``, we can define one or both of the classes to be used::
+
+    from sqlalchemy.ext.automap import automap_base
+    from sqlalchemy.orm import relationship
+    from sqlalchemy import create_engine, Column, String
+
+    # automap base
+    Base = automap_base()
+
+    # pre-declare User for the 'user' table
+    class User(Base):
+        __tablename__ = 'user'
+
+        # override schema elements like Columns
+        user_name = Column('name', String)
+
+        # override relationships too, if desired.
+        # we must use the same name that automap would use for the
+        # relationship, and also must refer to the class name that automap will
+        # generate for "address"
+        address_collection = relationship("address", collection_class=set)
+
+    # reflect
+    engine = create_engine("sqlite:///mydatabase.db")
+    Base.prepare(autoload_with=engine)
+
+    # we still have Address generated from the tablename "address",
+    # but User is the same as Base.classes.User now
+
+    Address = Base.classes.address
+
+    u1 = session.query(User).first()
+    print(u1.address_collection)
+
+    # the backref is still there:
+    a1 = session.query(Address).first()
+    print(a1.user)
+
+Above, one of the more intricate details is that we illustrated overriding
+one of the :func:`_orm.relationship` objects that automap would have created.
+To do this, we needed to make sure the names match up with what automap
+would normally generate, in that the relationship name would be
+``User.address_collection`` and the name of the class referred to, from
+automap's perspective, is called ``address``, even though we are referring to
+it as ``Address`` within our usage of this class.
+
+Overriding Naming Schemes
+=========================
+
+:mod:`.sqlalchemy.ext.automap` is tasked with producing mapped classes and
+relationship names based on a schema, which means it has decision points in how
+these names are determined.  These three decision points are provided using
+functions which can be passed to the :meth:`.AutomapBase.prepare` method, and
+are known as :func:`.classname_for_table`,
+:func:`.name_for_scalar_relationship`,
+and :func:`.name_for_collection_relationship`.  Any or all of these
+functions are provided as in the example below, where we use a "camel case"
+scheme for class names and a "pluralizer" for collection names using the
+`Inflect <https://pypi.org/project/inflect>`_ package::
+
+    import re
+    import inflect
+
+    from sqlalchemy import create_engine
+    from sqlalchemy.ext.automap import automap_base
+
+    def camelize_classname(base, tablename, table):
+        """Produce a 'camelized' class name, e.g.
+        'words_and_underscores' -> 'WordsAndUnderscores'"""
+
+        return str(tablename[0].upper() + \
+                re.sub(r'_([a-z])', lambda m: m.group(1).upper(), tablename[1:]))
+
+    _pluralizer = inflect.engine()
+    def pluralize_collection(base, local_cls, referred_cls, constraint):
+        """Produce an 'uncamelized', 'pluralized' class name, e.g.
+        'SomeTerm' -> 'some_terms'"""
+
+        referred_name = referred_cls.__name__
+        uncamelized = re.sub(r'[A-Z]',
+                             lambda m: "_%s" % m.group(0).lower(),
+                             referred_name)[1:]
+        pluralized = _pluralizer.plural(uncamelized)
+        return pluralized
+
+    Base = automap_base()
+
+    engine = create_engine("sqlite:///mydatabase.db")
+
+    Base.prepare(autoload_with=engine,
+        classname_for_table=camelize_classname,
+        name_for_collection_relationship=pluralize_collection
+    )
+
+From the above mapping, we would now have classes ``User`` and ``Address``,
+where the collection from ``User`` to ``Address`` is called
+``User.addresses``::
+
+    User, Address = Base.classes.User, Base.classes.Address
+
+    u1 = User(addresses=[Address(email="foo@bar.com")])
+
+Relationship Detection
+======================
+
+The vast majority of what automap accomplishes is the generation of
+:func:`_orm.relationship` structures based on foreign keys.  The mechanism
+by which this works for many-to-one and one-to-many relationships is as
+follows:
+
+1. A given :class:`_schema.Table`, known to be mapped to a particular class,
+   is examined for :class:`_schema.ForeignKeyConstraint` objects.
+
+2. From each :class:`_schema.ForeignKeyConstraint`, the remote
+   :class:`_schema.Table`
+   object present is matched up to the class to which it is to be mapped,
+   if any, else it is skipped.
+
+3. As the :class:`_schema.ForeignKeyConstraint`
+   we are examining corresponds to a
+   reference from the immediate mapped class, the relationship will be set up
+   as a many-to-one referring to the referred class; a corresponding
+   one-to-many backref will be created on the referred class referring
+   to this class.
+
+4. If any of the columns that are part of the
+   :class:`_schema.ForeignKeyConstraint`
+   are not nullable (e.g. ``nullable=False``), a
+   :paramref:`_orm.relationship.cascade` keyword argument
+   of ``all, delete-orphan`` will be added to the keyword arguments to
+   be passed to the relationship or backref.  If the
+   :class:`_schema.ForeignKeyConstraint` reports that
+   :paramref:`_schema.ForeignKeyConstraint.ondelete`
+   is set to ``CASCADE`` for a non-nullable set of columns, or ``SET NULL``
+   for a nullable set, the option :paramref:`_orm.relationship.passive_deletes`
+   flag is set to ``True`` in the set of relationship keyword arguments.
+   Note that not all backends support reflection of ON DELETE.
+
+   .. versionadded:: 1.0.0 - automap will detect non-nullable foreign key
+      constraints when producing a one-to-many relationship and establish
+      a default cascade of ``all, delete-orphan`` if so; additionally,
+      if the constraint specifies
+      :paramref:`_schema.ForeignKeyConstraint.ondelete`
+      of ``CASCADE`` for non-nullable or ``SET NULL`` for nullable columns,
+      the ``passive_deletes=True`` option is also added.
+
+5. The names of the relationships are determined using the
+   :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` and
+   :paramref:`.AutomapBase.prepare.name_for_collection_relationship`
+   callable functions.  It is important to note that the default relationship
+   naming derives the name from the **actual class name**.  If you've
+   given a particular class an explicit name by declaring it, or specified an
+   alternate class naming scheme, that's the name from which the relationship
+   name will be derived.
+
+6. The classes are inspected for an existing mapped property matching these
+   names.  
If one is detected on one side, but none on the other side,
+   :class:`.AutomapBase` attempts to create a relationship on the missing side,
+   then uses the :paramref:`_orm.relationship.back_populates`
+   parameter in order to
+   point the new relationship to the other side.
+
+7. In the usual case where no relationship is on either side,
+   :meth:`.AutomapBase.prepare` produces a :func:`_orm.relationship` on the
+   "many-to-one" side and matches it to the other using the
+   :paramref:`_orm.relationship.backref` parameter.
+
+8. Production of the :func:`_orm.relationship` and optionally the
+   :func:`.backref`
+   is handed off to the :paramref:`.AutomapBase.prepare.generate_relationship`
+   function, which can be supplied by the end-user in order to augment
+   the arguments passed to :func:`_orm.relationship` or :func:`.backref` or to
+   make use of custom implementations of these functions.
+
+Custom Relationship Arguments
+-----------------------------
+
+The :paramref:`.AutomapBase.prepare.generate_relationship` hook can be used
+to add parameters to relationships.  For most cases, we can make use of the
+existing :func:`.automap.generate_relationship` function to return
+the object, after augmenting the given keyword dictionary with our own
+arguments.
+
+Below is an illustration of how to send
+:paramref:`_orm.relationship.cascade` and
+:paramref:`_orm.relationship.passive_deletes`
+options along to all one-to-many relationships::
+
+    from sqlalchemy.ext.automap import generate_relationship
+    from sqlalchemy.orm import interfaces
+
+    def _gen_relationship(base, direction, return_fn,
+                                attrname, local_cls, referred_cls, **kw):
+        if direction is interfaces.ONETOMANY:
+            kw['cascade'] = 'all, delete-orphan'
+            kw['passive_deletes'] = True
+        # make use of the built-in function to actually return
+        # the result.
+        return generate_relationship(base, direction, return_fn,
+                                     attrname, local_cls, referred_cls, **kw)
+
+    from sqlalchemy.ext.automap import automap_base
+    from sqlalchemy import create_engine
+
+    # automap base
+    Base = automap_base()
+
+    engine = create_engine("sqlite:///mydatabase.db")
+    Base.prepare(autoload_with=engine,
+                 generate_relationship=_gen_relationship)
+
+Many-to-Many Relationships
+--------------------------
+
+:mod:`.sqlalchemy.ext.automap` will generate many-to-many relationships, e.g.
+those which contain a ``secondary`` argument.  The process for producing these
+is as follows:
+
+1. A given :class:`_schema.Table` is examined for
+   :class:`_schema.ForeignKeyConstraint`
+   objects, before any mapped class has been assigned to it.
+
+2. If the table contains exactly two
+   :class:`_schema.ForeignKeyConstraint`
+   objects, and all columns within this table are members of these two
+   :class:`_schema.ForeignKeyConstraint` objects, the table is assumed to be a
+   "secondary" table, and will **not be mapped directly**.
+
+3. The two (or one, for self-referential) external tables to which the
+   :class:`_schema.Table`
+   refers are matched to the classes to which they will be
+   mapped, if any.
+
+4. If mapped classes for both sides are located, a many-to-many bi-directional
+   :func:`_orm.relationship` / :func:`.backref`
+   pair is created between the two
+   classes.
+
+5. The override logic for many-to-many works the same as that of one-to-many/
+   many-to-one; the :func:`.generate_relationship` function is called upon
+   to generate the structures and existing attributes will be maintained.
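+
+As a rough sketch of the above steps, the following layout would be
+detected as a many-to-many (the table names here are illustrative, and
+the collection names shown assume the default naming scheme)::
+
+    from sqlalchemy import Column, ForeignKey, Integer, MetaData, Table
+    from sqlalchemy.ext.automap import automap_base
+
+    metadata = MetaData()
+
+    Table('user', metadata, Column('id', Integer, primary_key=True))
+    Table('keyword', metadata, Column('id', Integer, primary_key=True))
+
+    # every column of this table is a member of one of its two
+    # ForeignKeyConstraint objects, so it is treated as a "secondary"
+    # table and is not mapped directly
+    Table('user_keyword', metadata,
+        Column('user_id', ForeignKey('user.id'), primary_key=True),
+        Column('keyword_id', ForeignKey('keyword.id'), primary_key=True)
+    )
+
+    Base = automap_base(metadata=metadata)
+    Base.prepare()
+
+    User, Keyword = Base.classes.user, Base.classes.keyword
+
+    # bidirectional many-to-many collections were generated
+    u1 = User(keyword_collection=[Keyword()])
+    assert u1.keyword_collection[0].user_collection[0] is u1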
+
+Relationships with Inheritance
+------------------------------
+
+:mod:`.sqlalchemy.ext.automap` will not generate any relationships between
+two classes that are in an inheritance relationship.  That is, with two
+classes given as follows::
+
+    class Employee(Base):
+        __tablename__ = 'employee'
+        id = Column(Integer, primary_key=True)
+        type = Column(String(50))
+        __mapper_args__ = {
+             'polymorphic_identity':'employee', 'polymorphic_on': type
+        }
+
+    class Engineer(Employee):
+        __tablename__ = 'engineer'
+        id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
+        __mapper_args__ = {
+            'polymorphic_identity':'engineer',
+        }
+
+The foreign key from ``Engineer`` to ``Employee`` is used not for a
+relationship, but to establish joined inheritance between the two classes.
+
+Note that this means automap will not generate *any* relationships
+for foreign keys that link from a subclass to a superclass.  If a mapping
+has actual relationships from subclass to superclass as well, those
+need to be explicit.  Below, as we have two separate foreign keys
+from ``Engineer`` to ``Employee``, we need to set up both the relationship
+we want as well as the ``inherit_condition``, as these are not things
+SQLAlchemy can guess::
+
+    class Employee(Base):
+        __tablename__ = 'employee'
+        id = Column(Integer, primary_key=True)
+        type = Column(String(50))
+
+        __mapper_args__ = {
+            'polymorphic_identity':'employee', 'polymorphic_on':type
+        }
+
+    class Engineer(Employee):
+        __tablename__ = 'engineer'
+        id = Column(Integer, ForeignKey('employee.id'), primary_key=True)
+        favorite_employee_id = Column(Integer, ForeignKey('employee.id'))
+
+        favorite_employee = relationship(Employee,
+                                         foreign_keys=favorite_employee_id)
+
+        __mapper_args__ = {
+            'polymorphic_identity':'engineer',
+            'inherit_condition': id == Employee.id
+        }
+
+Handling Simple Naming Conflicts
+--------------------------------
+
+In the case of naming conflicts during mapping, override any of
+:func:`.classname_for_table`, :func:`.name_for_scalar_relationship`,
+and :func:`.name_for_collection_relationship` as needed.  For example, if
+automap is attempting to name a many-to-one relationship the same as an
+existing column, an alternate convention can be conditionally selected.  Given
+a schema:
+
+.. sourcecode:: sql
+
+    CREATE TABLE table_a (
+        id INTEGER PRIMARY KEY
+    );
+
+    CREATE TABLE table_b (
+        id INTEGER PRIMARY KEY,
+        table_a INTEGER,
+        FOREIGN KEY(table_a) REFERENCES table_a(id)
+    );
+
+The above schema will first automap the ``table_a`` table as a class named
+``table_a``; it will then automap a relationship onto the class for ``table_b``
+with the same name as this related class, e.g. ``table_a``.  This
+relationship name conflicts with the mapped column ``table_b.table_a``,
+and an error will be raised at mapping time.
+
+We can resolve this conflict by using an underscore as follows::
+
+    import warnings
+
+    def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
+        name = referred_cls.__name__.lower()
+        local_table = local_cls.__table__
+        if name in local_table.columns:
+            newname = name + "_"
+            warnings.warn(
+                "Already detected name %s present.  Using %s" %
+                (name, newname))
+            return newname
+        return name
+
+
+    Base.prepare(autoload_with=engine,
+        name_for_scalar_relationship=name_for_scalar_relationship)
+
+Alternatively, we can change the name on the column side.  
The columns
+that are mapped can be modified using the technique described at
+:ref:`mapper_column_distinct_names`, by assigning the column explicitly
+to a new name::
+
+    Base = automap_base()
+
+    class TableB(Base):
+        __tablename__ = 'table_b'
+        _table_a = Column('table_a', ForeignKey('table_a.id'))
+
+    Base.prepare(autoload_with=engine)
+
+
+Using Automap with Explicit Declarations
+========================================
+
+As noted previously, automap has no dependency on reflection, and can make
+use of any collection of :class:`_schema.Table` objects within a
+:class:`_schema.MetaData`
+collection.  From this, it follows that automap can also be used to
+generate missing relationships given an otherwise complete model that fully
+defines table metadata::
+
+    from sqlalchemy.ext.automap import automap_base
+    from sqlalchemy import Column, Integer, String, ForeignKey
+
+    Base = automap_base()
+
+    class User(Base):
+        __tablename__ = 'user'
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+    class Address(Base):
+        __tablename__ = 'address'
+
+        id = Column(Integer, primary_key=True)
+        email = Column(String)
+        user_id = Column(ForeignKey('user.id'))
+
+    # produce relationships
+    Base.prepare()
+
+    # mapping is complete, with "address_collection" and
+    # "user" relationships
+    a1 = Address(email='u1')
+    a2 = Address(email='u2')
+    u1 = User(address_collection=[a1, a2])
+    assert a1.user is u1
+
+Above, given mostly complete ``User`` and ``Address`` mappings, the
+:class:`_schema.ForeignKey` which we defined on ``Address.user_id`` allowed a
+bidirectional relationship pair ``Address.user`` and
+``User.address_collection`` to be generated on the mapped classes.
+
+Note that when subclassing :class:`.AutomapBase`,
+the :meth:`.AutomapBase.prepare` method is required; if not called, the classes
+we've declared are in an un-mapped state.
+
+
+.. _automap_intercepting_columns:
+
+Intercepting Column Definitions
+===============================
+
+The :class:`_schema.MetaData` and :class:`_schema.Table` objects support an
+event hook :meth:`_events.DDLEvents.column_reflect` that may be used to intercept
+the information reflected about a database column before the :class:`_schema.Column`
+object is constructed.  For example, if we wanted to map columns using a
+naming convention such as ``"attr_<columnname>"``, the event could
+be applied as::
+
+    from sqlalchemy import event
+
+    @event.listens_for(Base.metadata, "column_reflect")
+    def column_reflect(inspector, table, column_info):
+        # set column.key = "attr_<lower_case_name>"
+        column_info['key'] = "attr_%s" % column_info['name'].lower()
+
+    # run reflection
+    Base.prepare(autoload_with=engine)
+
+.. versionadded:: 1.4.0b2 the :meth:`_events.DDLEvents.column_reflect` event
+   may be applied to a :class:`_schema.MetaData` object.
+
+.. seealso::
+
+    :meth:`_events.DDLEvents.column_reflect`
+
+    :ref:`mapper_automated_reflection_schemes` - in the ORM mapping documentation
+
+
+"""  # noqa
+from .. import util
+from ..orm import backref
+from ..orm import declarative_base as _declarative_base
+from ..orm import exc as orm_exc
+from ..orm import interfaces
+from ..orm import relationship
+from ..orm.decl_base import _DeferredMapperConfig
+from ..orm.mapper import _CONFIGURE_MUTEX
+from ..schema import ForeignKeyConstraint
+from ..sql import and_
+
+
+def classname_for_table(base, tablename, table):
+    """Return the class name that should be used, given the name
+    of a table.
+ + The default implementation is:: + + return str(tablename) + + Alternate implementations can be specified using the + :paramref:`.AutomapBase.prepare.classname_for_table` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param tablename: string name of the :class:`_schema.Table`. + + :param table: the :class:`_schema.Table` object itself. + + :return: a string class name. + + .. note:: + + In Python 2, the string used for the class name **must** be a + non-Unicode object, e.g. a ``str()`` object. The ``.name`` attribute + of :class:`_schema.Table` is typically a Python unicode subclass, + so the + ``str()`` function should be applied to this name, after accounting for + any non-ASCII characters. + + """ + return str(tablename) + + +def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): + """Return the attribute name that should be used to refer from one + class to another, for a scalar object reference. + + The default implementation is:: + + return referred_cls.__name__.lower() + + Alternate implementations can be specified using the + :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param local_cls: the class to be mapped on the local side. + + :param referred_cls: the class to be mapped on the referring side. + + :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being + inspected to produce this relationship. + + """ + return referred_cls.__name__.lower() + + +def name_for_collection_relationship( + base, local_cls, referred_cls, constraint +): + """Return the attribute name that should be used to refer from one + class to another, for a collection reference. + + The default implementation is:: + + return referred_cls.__name__.lower() + "_collection" + + Alternate implementations + can be specified using the + :paramref:`.AutomapBase.prepare.name_for_collection_relationship` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param local_cls: the class to be mapped on the local side. + + :param referred_cls: the class to be mapped on the referring side. + + :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being + inspected to produce this relationship. + + """ + return referred_cls.__name__.lower() + "_collection" + + +def generate_relationship( + base, direction, return_fn, attrname, local_cls, referred_cls, **kw +): + r"""Generate a :func:`_orm.relationship` or :func:`.backref` + on behalf of two + mapped classes. + + An alternate implementation of this function can be specified using the + :paramref:`.AutomapBase.prepare.generate_relationship` parameter. + + The default implementation of this function is as follows:: + + if return_fn is backref: + return return_fn(attrname, **kw) + elif return_fn is relationship: + return return_fn(referred_cls, **kw) + else: + raise TypeError("Unknown relationship function: %s" % return_fn) + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param direction: indicate the "direction" of the relationship; this will + be one of :data:`.ONETOMANY`, :data:`.MANYTOONE`, :data:`.MANYTOMANY`. + + :param return_fn: the function that is used by default to create the + relationship. This will be either :func:`_orm.relationship` or + :func:`.backref`. 
The :func:`.backref` function's result will be used to + produce a new :func:`_orm.relationship` in a second step, + so it is critical + that user-defined implementations correctly differentiate between the two + functions, if a custom relationship function is being used. + + :param attrname: the attribute name to which this relationship is being + assigned. If the value of :paramref:`.generate_relationship.return_fn` is + the :func:`.backref` function, then this name is the name that is being + assigned to the backref. + + :param local_cls: the "local" class to which this relationship or backref + will be locally present. + + :param referred_cls: the "referred" class to which the relationship or + backref refers to. + + :param \**kw: all additional keyword arguments are passed along to the + function. + + :return: a :func:`_orm.relationship` or :func:`.backref` construct, + as dictated + by the :paramref:`.generate_relationship.return_fn` parameter. + + """ + if return_fn is backref: + return return_fn(attrname, **kw) + elif return_fn is relationship: + return return_fn(referred_cls, **kw) + else: + raise TypeError("Unknown relationship function: %s" % return_fn) + + +class AutomapBase(object): + """Base class for an "automap" schema. + + The :class:`.AutomapBase` class can be compared to the "declarative base" + class that is produced by the :func:`.declarative.declarative_base` + function. In practice, the :class:`.AutomapBase` class is always used + as a mixin along with an actual declarative base. + + A new subclassable :class:`.AutomapBase` is typically instantiated + using the :func:`.automap_base` function. + + .. seealso:: + + :ref:`automap_toplevel` + + """ + + __abstract__ = True + + classes = None + """An instance of :class:`.util.Properties` containing classes. + + This object behaves much like the ``.c`` collection on a table. Classes + are present under the name they were given, e.g.:: + + Base = automap_base() + Base.prepare(autoload_with=some_engine) + + User, Address = Base.classes.User, Base.classes.Address + + """ + + @classmethod + @util.deprecated_params( + engine=( + "2.0", + "The :paramref:`_automap.AutomapBase.prepare.engine` parameter " + "is deprecated and will be removed in a future release. " + "Please use the " + ":paramref:`_automap.AutomapBase.prepare.autoload_with` " + "parameter.", + ), + reflect=( + "2.0", + "The :paramref:`_automap.AutomapBase.prepare.reflect` " + "parameter is deprecated and will be removed in a future " + "release. Reflection is enabled when " + ":paramref:`_automap.AutomapBase.prepare.autoload_with` " + "is passed.", + ), + ) + def prepare( + cls, + autoload_with=None, + engine=None, + reflect=False, + schema=None, + classname_for_table=None, + collection_class=None, + name_for_scalar_relationship=None, + name_for_collection_relationship=None, + generate_relationship=None, + reflection_options=util.EMPTY_DICT, + ): + """Extract mapped classes and relationships from the + :class:`_schema.MetaData` and + perform mappings. + + :param engine: an :class:`_engine.Engine` or + :class:`_engine.Connection` with which + to perform schema reflection, if specified. + If the :paramref:`.AutomapBase.prepare.reflect` argument is False, + this object is not used. + + :param reflect: if True, the :meth:`_schema.MetaData.reflect` + method is called + on the :class:`_schema.MetaData` associated with this + :class:`.AutomapBase`. 
+          The :class:`_engine.Engine` passed via
+          :paramref:`.AutomapBase.prepare.engine` will be used to perform the
+          reflection if present; otherwise, the :class:`_schema.MetaData`
+          should already be
+          bound to some engine, else the operation will fail.
+
+        :param classname_for_table: callable function which will be used to
+          produce new class names, given a table name.  Defaults to
+          :func:`.classname_for_table`.
+
+        :param name_for_scalar_relationship: callable function which will be
+          used to produce relationship names for scalar relationships.  Defaults
+          to :func:`.name_for_scalar_relationship`.
+
+        :param name_for_collection_relationship: callable function which will
+          be used to produce relationship names for collection-oriented
+          relationships.  Defaults to :func:`.name_for_collection_relationship`.
+
+        :param generate_relationship: callable function which will be used to
+          actually generate :func:`_orm.relationship` and :func:`.backref`
+          constructs.  Defaults to :func:`.generate_relationship`.
+
+        :param collection_class: the Python collection class that will be used
+          when a new :func:`_orm.relationship`
+          object is created that represents a
+          collection.  Defaults to ``list``.
+
+        :param schema: When present in conjunction with the
+          :paramref:`.AutomapBase.prepare.reflect` flag, is passed to
+          :meth:`_schema.MetaData.reflect`
+          to indicate the primary schema where tables
+          should be reflected from.  When omitted, the default schema in use
+          by the database connection is used.
+
+          .. versionadded:: 1.1
+
+        :param reflection_options: When present, this dictionary of options
+          will be passed to :meth:`_schema.MetaData.reflect`
+          to supply general reflection-specific options like ``only`` and/or
+          dialect-specific options like ``oracle_resolve_synonyms``.
+
+          .. 
versionadded:: 1.4 + + """ + glbls = globals() + if classname_for_table is None: + classname_for_table = glbls["classname_for_table"] + if name_for_scalar_relationship is None: + name_for_scalar_relationship = glbls[ + "name_for_scalar_relationship" + ] + if name_for_collection_relationship is None: + name_for_collection_relationship = glbls[ + "name_for_collection_relationship" + ] + if generate_relationship is None: + generate_relationship = glbls["generate_relationship"] + if collection_class is None: + collection_class = list + + if autoload_with: + reflect = True + + if engine: + autoload_with = engine + + if reflect: + opts = dict( + schema=schema, + extend_existing=True, + autoload_replace=False, + ) + if reflection_options: + opts.update(reflection_options) + cls.metadata.reflect(autoload_with, **opts) + + with _CONFIGURE_MUTEX: + table_to_map_config = dict( + (m.local_table, m) + for m in _DeferredMapperConfig.classes_for_base( + cls, sort=False + ) + ) + + many_to_many = [] + + for table in cls.metadata.tables.values(): + lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table) + if lcl_m2m is not None: + many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table)) + elif not table.primary_key: + continue + elif table not in table_to_map_config: + mapped_cls = type( + classname_for_table(cls, table.name, table), + (cls,), + {"__table__": table}, + ) + map_config = _DeferredMapperConfig.config_for_cls( + mapped_cls + ) + cls.classes[map_config.cls.__name__] = mapped_cls + table_to_map_config[table] = map_config + + for map_config in table_to_map_config.values(): + _relationships_for_fks( + cls, + map_config, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) + + for lcl_m2m, rem_m2m, m2m_const, table in many_to_many: + _m2m_relationship( + cls, + lcl_m2m, + rem_m2m, + m2m_const, + table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) + + for map_config in _DeferredMapperConfig.classes_for_base(cls): + map_config.map() + + _sa_decl_prepare = True + """Indicate that the mapping of classes should be deferred. + + The presence of this attribute name indicates to declarative + that the call to mapper() should not occur immediately; instead, + information about the table and attributes to be mapped are gathered + into an internal structure called _DeferredMapperConfig. These + objects can be collected later using classes_for_base(), additional + mapping decisions can be made, and then the map() method will actually + apply the mapping. + + The only real reason this deferral of the whole + thing is needed is to support primary key columns that aren't reflected + yet when the class is declared; everything else can theoretically be + added to the mapper later. However, the _DeferredMapperConfig is a + nice interface in any case which exists at that not usually exposed point + at which declarative has the class and the Table but hasn't called + mapper() yet. + + """ + + @classmethod + def _sa_raise_deferred_config(cls): + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of AutomapBase. " + "Mappings are not produced until the .prepare() " + "method is called on the class hierarchy." + % orm_exc._safe_cls_name(cls), + ) + + +def automap_base(declarative_base=None, **kw): + r"""Produce a declarative automap base. 
+
+    This function produces a new base class that is a product of the
+    :class:`.AutomapBase` class as well as a declarative base produced by
+    :func:`.declarative.declarative_base`.
+
+    All parameters other than ``declarative_base`` are keyword arguments
+    that are passed directly to the :func:`.declarative.declarative_base`
+    function.
+
+    :param declarative_base: an existing class produced by
+      :func:`.declarative.declarative_base`.  When this is passed, the function
+      no longer invokes :func:`.declarative.declarative_base` itself, and all
+      other keyword arguments are ignored.
+
+    :param \**kw: keyword arguments are passed along to
+      :func:`.declarative.declarative_base`.
+
+    """
+    if declarative_base is None:
+        Base = _declarative_base(**kw)
+    else:
+        Base = declarative_base
+
+    return type(
+        Base.__name__,
+        (AutomapBase, Base),
+        {"__abstract__": True, "classes": util.Properties({})},
+    )
+
+
+def _is_many_to_many(automap_base, table):
+    fk_constraints = [
+        const
+        for const in table.constraints
+        if isinstance(const, ForeignKeyConstraint)
+    ]
+    if len(fk_constraints) != 2:
+        return None, None, None
+
+    cols = sum(
+        [
+            [fk.parent for fk in fk_constraint.elements]
+            for fk_constraint in fk_constraints
+        ],
+        [],
+    )
+
+    if set(cols) != set(table.c):
+        return None, None, None
+
+    return (
+        fk_constraints[0].elements[0].column.table,
+        fk_constraints[1].elements[0].column.table,
+        fk_constraints,
+    )
+
+
+def _relationships_for_fks(
+    automap_base,
+    map_config,
+    table_to_map_config,
+    collection_class,
+    name_for_scalar_relationship,
+    name_for_collection_relationship,
+    generate_relationship,
+):
+    local_table = map_config.local_table
+    local_cls = map_config.cls  # derived from a weakref, may be None
+
+    if local_table is None or local_cls is None:
+        return
+    for constraint in local_table.constraints:
+        if isinstance(constraint, ForeignKeyConstraint):
+            fks = constraint.elements
+            referred_table = fks[0].column.table
+            referred_cfg = table_to_map_config.get(referred_table, None)
+            if referred_cfg is None:
+                continue
+            referred_cls = referred_cfg.cls
+
+            if local_cls is not referred_cls and issubclass(
+                local_cls, referred_cls
+            ):
+                continue
+
+            relationship_name = name_for_scalar_relationship(
+                automap_base, local_cls, referred_cls, constraint
+            )
+            backref_name = name_for_collection_relationship(
+                automap_base, referred_cls, local_cls, constraint
+            )
+
+            o2m_kws = {}
+            nullable = False not in {fk.parent.nullable for fk in fks}
+            if not nullable:
+                o2m_kws["cascade"] = "all, delete-orphan"
+
+                if (
+                    constraint.ondelete
+                    and constraint.ondelete.lower() == "cascade"
+                ):
+                    o2m_kws["passive_deletes"] = True
+            else:
+                if (
+                    constraint.ondelete
+                    and constraint.ondelete.lower() == "set null"
+                ):
+                    o2m_kws["passive_deletes"] = True
+
+            create_backref = backref_name not in referred_cfg.properties
+
+            if relationship_name not in map_config.properties:
+                if create_backref:
+                    backref_obj = generate_relationship(
+                        automap_base,
+                        interfaces.ONETOMANY,
+                        backref,
+                        backref_name,
+                        referred_cls,
+                        local_cls,
+                        collection_class=collection_class,
+                        **o2m_kws
+                    )
+                else:
+                    backref_obj = None
+                rel = generate_relationship(
+                    automap_base,
+                    interfaces.MANYTOONE,
+                    relationship,
+                    relationship_name,
+                    local_cls,
+                    referred_cls,
+                    foreign_keys=[fk.parent for fk in constraint.elements],
+                    backref=backref_obj,
+                    remote_side=[fk.column for fk in constraint.elements],
+                )
+                if rel is not None:
+                    map_config.properties[relationship_name] = rel
+                    if not create_backref:
+                        referred_cfg.properties[
+                            
backref_name + ].back_populates = relationship_name + elif create_backref: + rel = generate_relationship( + automap_base, + interfaces.ONETOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + back_populates=relationship_name, + collection_class=collection_class, + **o2m_kws + ) + if rel is not None: + referred_cfg.properties[backref_name] = rel + map_config.properties[ + relationship_name + ].back_populates = backref_name + + +def _m2m_relationship( + automap_base, + lcl_m2m, + rem_m2m, + m2m_const, + table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, +): + + map_config = table_to_map_config.get(lcl_m2m, None) + referred_cfg = table_to_map_config.get(rem_m2m, None) + if map_config is None or referred_cfg is None: + return + + local_cls = map_config.cls + referred_cls = referred_cfg.cls + + relationship_name = name_for_collection_relationship( + automap_base, local_cls, referred_cls, m2m_const[0] + ) + backref_name = name_for_collection_relationship( + automap_base, referred_cls, local_cls, m2m_const[1] + ) + + create_backref = backref_name not in referred_cfg.properties + + if table in table_to_map_config: + overlaps = "__*" + else: + overlaps = None + + if relationship_name not in map_config.properties: + if create_backref: + backref_obj = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + backref, + backref_name, + referred_cls, + local_cls, + collection_class=collection_class, + overlaps=overlaps, + ) + else: + backref_obj = None + + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + relationship_name, + local_cls, + referred_cls, + overlaps=overlaps, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), + backref=backref_obj, + collection_class=collection_class, + ) + if rel is not None: + map_config.properties[relationship_name] = rel + + if not create_backref: + referred_cfg.properties[ + backref_name + ].back_populates = relationship_name + elif create_backref: + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + overlaps=overlaps, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), + back_populates=relationship_name, + collection_class=collection_class, + ) + if rel is not None: + referred_cfg.properties[backref_name] = rel + map_config.properties[ + relationship_name + ].back_populates = backref_name diff --git a/lib/sqlalchemy/ext/baked.py b/lib/sqlalchemy/ext/baked.py new file mode 100644 index 0000000..109e0c0 --- /dev/null +++ b/lib/sqlalchemy/ext/baked.py @@ -0,0 +1,648 @@ +# sqlalchemy/ext/baked.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +"""Baked query extension. + +Provides a creational pattern for the :class:`.query.Query` object which +allows the fully constructed object, Core select statement, and string +compiled result to be fully cached. + + +""" + +import logging + +from .. import exc as sa_exc +from .. 
import util +from ..orm import exc as orm_exc +from ..orm import strategy_options +from ..orm.query import Query +from ..orm.session import Session +from ..sql import func +from ..sql import literal_column +from ..sql import util as sql_util +from ..util import collections_abc + + +log = logging.getLogger(__name__) + + +class Bakery(object): + """Callable which returns a :class:`.BakedQuery`. + + This object is returned by the class method + :meth:`.BakedQuery.bakery`. It exists as an object + so that the "cache" can be easily inspected. + + .. versionadded:: 1.2 + + + """ + + __slots__ = "cls", "cache" + + def __init__(self, cls_, cache): + self.cls = cls_ + self.cache = cache + + def __call__(self, initial_fn, *args): + return self.cls(self.cache, initial_fn, args) + + +class BakedQuery(object): + """A builder object for :class:`.query.Query` objects.""" + + __slots__ = "steps", "_bakery", "_cache_key", "_spoiled" + + def __init__(self, bakery, initial_fn, args=()): + self._cache_key = () + self._update_cache_key(initial_fn, args) + self.steps = [initial_fn] + self._spoiled = False + self._bakery = bakery + + @classmethod + def bakery(cls, size=200, _size_alert=None): + """Construct a new bakery. + + :return: an instance of :class:`.Bakery` + + """ + + return Bakery(cls, util.LRUCache(size, size_alert=_size_alert)) + + def _clone(self): + b1 = BakedQuery.__new__(BakedQuery) + b1._cache_key = self._cache_key + b1.steps = list(self.steps) + b1._bakery = self._bakery + b1._spoiled = self._spoiled + return b1 + + def _update_cache_key(self, fn, args=()): + self._cache_key += (fn.__code__,) + args + + def __iadd__(self, other): + if isinstance(other, tuple): + self.add_criteria(*other) + else: + self.add_criteria(other) + return self + + def __add__(self, other): + if isinstance(other, tuple): + return self.with_criteria(*other) + else: + return self.with_criteria(other) + + def add_criteria(self, fn, *args): + """Add a criteria function to this :class:`.BakedQuery`. + + This is equivalent to using the ``+=`` operator to + modify a :class:`.BakedQuery` in-place. + + """ + self._update_cache_key(fn, args) + self.steps.append(fn) + return self + + def with_criteria(self, fn, *args): + """Add a criteria function to a :class:`.BakedQuery` cloned from this + one. + + This is equivalent to using the ``+`` operator to + produce a new :class:`.BakedQuery` with modifications. + + """ + return self._clone().add_criteria(fn, *args) + + def for_session(self, session): + """Return a :class:`_baked.Result` object for this + :class:`.BakedQuery`. + + This is equivalent to calling the :class:`.BakedQuery` as a + Python callable, e.g. ``result = my_baked_query(session)``. + + """ + return Result(self, session) + + def __call__(self, session): + return self.for_session(session) + + def spoil(self, full=False): + """Cancel any query caching that will occur on this BakedQuery object. + + The BakedQuery can continue to be used normally, however additional + creational functions will not be cached; they will be called + on every invocation. + + This is to support the case where a particular step in constructing + a baked query disqualifies the query from being cacheable, such + as a variant that relies upon some uncacheable value. + + :param full: if False, only functions added to this + :class:`.BakedQuery` object subsequent to the spoil step will be + non-cached; the state of the :class:`.BakedQuery` up until + this point will be pulled from the cache. 
If True, then the + entire :class:`_query.Query` object is built from scratch each + time, with all creational functions being called on each + invocation. + + """ + if not full and not self._spoiled: + _spoil_point = self._clone() + _spoil_point._cache_key += ("_query_only",) + self.steps = [_spoil_point._retrieve_baked_query] + self._spoiled = True + return self + + def _effective_key(self, session): + """Return the key that actually goes into the cache dictionary for + this :class:`.BakedQuery`, taking into account the given + :class:`.Session`. + + This basically means we also will include the session's query_class, + as the actual :class:`_query.Query` object is part of what's cached + and needs to match the type of :class:`_query.Query` that a later + session will want to use. + + """ + return self._cache_key + (session._query_cls,) + + def _with_lazyload_options(self, options, effective_path, cache_path=None): + """Cloning version of _add_lazyload_options.""" + q = self._clone() + q._add_lazyload_options(options, effective_path, cache_path=cache_path) + return q + + def _add_lazyload_options(self, options, effective_path, cache_path=None): + """Used by per-state lazy loaders to add options to the + "lazy load" query from a parent query. + + Creates a cache key based on given load path and query options; + if a repeatable cache key cannot be generated, the query is + "spoiled" so that it won't use caching. + + """ + + key = () + + if not cache_path: + cache_path = effective_path + + for opt in options: + if opt._is_legacy_option or opt._is_compile_state: + ck = opt._generate_cache_key() + if ck is None: + self.spoil(full=True) + else: + assert not ck[1], ( + "loader options with variable bound parameters " + "not supported with baked queries. Please " + "use new-style select() statements for cached " + "ORM queries." + ) + key += ck[0] + + self.add_criteria( + lambda q: q._with_current_path(effective_path).options(*options), + cache_path.path, + key, + ) + + def _retrieve_baked_query(self, session): + query = self._bakery.get(self._effective_key(session), None) + if query is None: + query = self._as_query(session) + self._bakery[self._effective_key(session)] = query.with_session( + None + ) + return query.with_session(session) + + def _bake(self, session): + query = self._as_query(session) + query.session = None + + # in 1.4, this is where before_compile() event is + # invoked + statement = query._statement_20() + + # if the query is not safe to cache, we still do everything as though + # we did cache it, since the receiver of _bake() assumes subqueryload + # context was set up, etc. + # + # note also we want to cache the statement itself because this + # allows the statement itself to hold onto its cache key that is + # used by the Connection, which in itself is more expensive to + # generate than what BakedQuery was able to provide in 1.3 and prior + + if statement._compile_options._bake_ok: + self._bakery[self._effective_key(session)] = ( + query, + statement, + ) + + return query, statement + + def to_query(self, query_or_session): + """Return the :class:`_query.Query` object for use as a subquery. + + This method should be used within the lambda callable being used + to generate a step of an enclosing :class:`.BakedQuery`. 
The
+        parameter should normally be the :class:`_query.Query` object that
+        is passed to the lambda::
+
+            sub_bq = self.bakery(lambda s: s.query(User.name))
+            sub_bq += lambda q: q.filter(
+                User.id == Address.user_id).correlate(Address)
+
+            main_bq = self.bakery(lambda s: s.query(Address))
+            main_bq += lambda q: q.filter(
+                sub_bq.to_query(q).exists())
+
+        In the case where the subquery is used in the first callable against
+        a :class:`.Session`, the :class:`.Session` is also accepted::
+
+            sub_bq = self.bakery(lambda s: s.query(User.name))
+            sub_bq += lambda q: q.filter(
+                User.id == Address.user_id).correlate(Address)
+
+            main_bq = self.bakery(
+                lambda s: s.query(
+                    Address.id, sub_bq.to_query(s).scalar_subquery())
+            )
+
+        :param query_or_session: a :class:`_query.Query` object or a
+          :class:`.Session` object, that is assumed to be within the context
+          of an enclosing :class:`.BakedQuery` callable.
+
+
+        .. versionadded:: 1.3
+
+
+        """
+
+        if isinstance(query_or_session, Session):
+            session = query_or_session
+        elif isinstance(query_or_session, Query):
+            session = query_or_session.session
+            if session is None:
+                raise sa_exc.ArgumentError(
+                    "Given Query needs to be associated with a Session"
+                )
+        else:
+            raise TypeError(
+                "Query or Session object expected, got %r."
+                % type(query_or_session)
+            )
+        return self._as_query(session)
+
+    def _as_query(self, session):
+        query = self.steps[0](session)
+
+        for step in self.steps[1:]:
+            query = step(query)
+
+        return query
+
+
+class Result(object):
+    """Invokes a :class:`.BakedQuery` against a :class:`.Session`.
+
+    The :class:`_baked.Result` object is where the actual :class:`.query.Query`
+    object gets created, or retrieved from the cache,
+    against a target :class:`.Session`, and is then invoked for results.
+
+    """
+
+    __slots__ = "bq", "session", "_params", "_post_criteria"
+
+    def __init__(self, bq, session):
+        self.bq = bq
+        self.session = session
+        self._params = {}
+        self._post_criteria = []
+
+    def params(self, *args, **kw):
+        """Specify parameters to be replaced into the string SQL statement."""
+
+        if len(args) == 1:
+            kw.update(args[0])
+        elif len(args) > 0:
+            raise sa_exc.ArgumentError(
+                "params() takes zero or one positional argument, "
+                "which is a dictionary."
+            )
+        self._params.update(kw)
+        return self
+
+    def _using_post_criteria(self, fns):
+        if fns:
+            self._post_criteria.extend(fns)
+        return self
+
+    def with_post_criteria(self, fn):
+        """Add a criteria function that will be applied post-cache.
+
+        This adds a function that will be run against the
+        :class:`_query.Query` object after it is retrieved from the
+        cache.  This currently includes **only** the
+        :meth:`_query.Query.params` and :meth:`_query.Query.execution_options`
+        methods.
+
+        .. warning::  :meth:`_baked.Result.with_post_criteria`
+           functions are applied
+           to the :class:`_query.Query`
+           object **after** the query's SQL statement
+           object has been retrieved from the cache.  Only
+           :meth:`_query.Query.params` and
+           :meth:`_query.Query.execution_options`
+           methods should be used.
+
+
+        .. 
versionadded:: 1.2 + + + """ + return self._using_post_criteria([fn]) + + def _as_query(self): + q = self.bq._as_query(self.session).params(self._params) + for fn in self._post_criteria: + q = fn(q) + return q + + def __str__(self): + return str(self._as_query()) + + def __iter__(self): + return self._iter().__iter__() + + def _iter(self): + bq = self.bq + + if not self.session.enable_baked_queries or bq._spoiled: + return self._as_query()._iter() + + query, statement = bq._bakery.get( + bq._effective_key(self.session), (None, None) + ) + if query is None: + query, statement = bq._bake(self.session) + + if self._params: + q = query.params(self._params) + else: + q = query + for fn in self._post_criteria: + q = fn(q) + + params = q._params + execution_options = dict(q._execution_options) + execution_options.update( + { + "_sa_orm_load_options": q.load_options, + "compiled_cache": bq._bakery, + } + ) + + result = self.session.execute( + statement, params, execution_options=execution_options + ) + if result._attributes.get("is_single_entity", False): + result = result.scalars() + + if result._attributes.get("filtered", False): + result = result.unique() + + return result + + def count(self): + """return the 'count'. + + Equivalent to :meth:`_query.Query.count`. + + Note this uses a subquery to ensure an accurate count regardless + of the structure of the original statement. + + .. versionadded:: 1.1.6 + + """ + + col = func.count(literal_column("*")) + bq = self.bq.with_criteria(lambda q: q._from_self(col)) + return bq.for_session(self.session).params(self._params).scalar() + + def scalar(self): + """Return the first element of the first result or None + if no rows present. If multiple rows are returned, + raises MultipleResultsFound. + + Equivalent to :meth:`_query.Query.scalar`. + + .. versionadded:: 1.1.6 + + """ + try: + ret = self.one() + if not isinstance(ret, collections_abc.Sequence): + return ret + return ret[0] + except orm_exc.NoResultFound: + return None + + def first(self): + """Return the first row. + + Equivalent to :meth:`_query.Query.first`. + + """ + + bq = self.bq.with_criteria(lambda q: q.slice(0, 1)) + return ( + bq.for_session(self.session) + .params(self._params) + ._using_post_criteria(self._post_criteria) + ._iter() + .first() + ) + + def one(self): + """Return exactly one result or raise an exception. + + Equivalent to :meth:`_query.Query.one`. + + """ + return self._iter().one() + + def one_or_none(self): + """Return one or zero results, or raise an exception for multiple + rows. + + Equivalent to :meth:`_query.Query.one_or_none`. + + .. versionadded:: 1.0.9 + + """ + return self._iter().one_or_none() + + def all(self): + """Return all rows. + + Equivalent to :meth:`_query.Query.all`. + + """ + return self._iter().all() + + def get(self, ident): + """Retrieve an object based on identity. + + Equivalent to :meth:`_query.Query.get`. 
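+
+        A rough usage sketch, assuming a mapped ``User`` class and an
+        active ``session``::
+
+            bakery = BakedQuery.bakery()
+            bq = bakery(lambda s: s.query(User))
+            user_obj = bq(session).get(5)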
+ + """ + + query = self.bq.steps[0](self.session) + return query._get_impl(ident, self._load_on_pk_identity) + + def _load_on_pk_identity(self, session, query, primary_key_identity, **kw): + """Load the given primary key identity from the database.""" + + mapper = query._raw_columns[0]._annotations["parententity"] + + _get_clause, _get_params = mapper._get_clause + + def setup(query): + _lcl_get_clause = _get_clause + q = query._clone() + q._get_condition() + q._order_by = None + + # None present in ident - turn those comparisons + # into "IS NULL" + if None in primary_key_identity: + nones = set( + [ + _get_params[col].key + for col, value in zip( + mapper.primary_key, primary_key_identity + ) + if value is None + ] + ) + _lcl_get_clause = sql_util.adapt_criterion_to_null( + _lcl_get_clause, nones + ) + + # TODO: can mapper._get_clause be pre-adapted? + q._where_criteria = ( + sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}), + ) + + for fn in self._post_criteria: + q = fn(q) + return q + + # cache the query against a key that includes + # which positions in the primary key are NULL + # (remember, we can map to an OUTER JOIN) + bq = self.bq + + # add the clause we got from mapper._get_clause to the cache + # key so that if a race causes multiple calls to _get_clause, + # we've cached on ours + bq = bq._clone() + bq._cache_key += (_get_clause,) + + bq = bq.with_criteria( + setup, tuple(elem is None for elem in primary_key_identity) + ) + + params = dict( + [ + (_get_params[primary_key].key, id_val) + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + ] + ) + + result = list(bq.for_session(self.session).params(**params)) + l = len(result) + if l > 1: + raise orm_exc.MultipleResultsFound() + elif l: + return result[0] + else: + return None + + +@util.deprecated( + "1.2", "Baked lazy loading is now the default implementation." +) +def bake_lazy_loaders(): + """Enable the use of baked queries for all lazyloaders systemwide. + + The "baked" implementation of lazy loading is now the sole implementation + for the base lazy loader; this method has no effect except for a warning. + + """ + pass + + +@util.deprecated( + "1.2", "Baked lazy loading is now the default implementation." +) +def unbake_lazy_loaders(): + """Disable the use of baked queries for all lazyloaders systemwide. + + This method now raises NotImplementedError() as the "baked" implementation + is the only lazy load implementation. The + :paramref:`_orm.relationship.bake_queries` flag may be used to disable + the caching of queries on a per-relationship basis. + + """ + raise NotImplementedError( + "Baked lazy loading is now the default implementation" + ) + + +@strategy_options.loader_option() +def baked_lazyload(loadopt, attr): + """Indicate that the given attribute should be loaded using "lazy" + loading with a "baked" query used in the load. 
+ + """ + return loadopt.set_relationship_strategy(attr, {"lazy": "baked_select"}) + + +@baked_lazyload._add_unbound_fn +@util.deprecated( + "1.2", + "Baked lazy loading is now the default " + "implementation for lazy loading.", +) +def baked_lazyload(*keys): + return strategy_options._UnboundLoad._from_keys( + strategy_options._UnboundLoad.baked_lazyload, keys, False, {} + ) + + +@baked_lazyload._add_unbound_all_fn +@util.deprecated( + "1.2", + "Baked lazy loading is now the default " + "implementation for lazy loading.", +) +def baked_lazyload_all(*keys): + return strategy_options._UnboundLoad._from_keys( + strategy_options._UnboundLoad.baked_lazyload, keys, True, {} + ) + + +baked_lazyload = baked_lazyload._unbound_fn +baked_lazyload_all = baked_lazyload_all._unbound_all_fn + +bakery = BakedQuery.bakery diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py new file mode 100644 index 0000000..76b59ea --- /dev/null +++ b/lib/sqlalchemy/ext/compiler.py @@ -0,0 +1,613 @@ +# ext/compiler.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Provides an API for creation of custom ClauseElements and compilers. + +Synopsis +======== + +Usage involves the creation of one or more +:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or +more callables defining their compilation:: + + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.sql.expression import ColumnClause + + class MyColumn(ColumnClause): + inherit_cache = True + + @compiles(MyColumn) + def compile_mycolumn(element, compiler, **kw): + return "[%s]" % element.name + +Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`, +the base expression element for named column objects. The ``compiles`` +decorator registers itself with the ``MyColumn`` class so that it is invoked +when the object is compiled to a string:: + + from sqlalchemy import select + + s = select(MyColumn('x'), MyColumn('y')) + print(str(s)) + +Produces:: + + SELECT [x], [y] + +Dialect-specific compilation rules +================================== + +Compilers can also be made dialect-specific. The appropriate compiler will be +invoked for the dialect in use:: + + from sqlalchemy.schema import DDLElement + + class AlterColumn(DDLElement): + inherit_cache = False + + def __init__(self, column, cmd): + self.column = column + self.cmd = cmd + + @compiles(AlterColumn) + def visit_alter_column(element, compiler, **kw): + return "ALTER COLUMN %s ..." % element.column.name + + @compiles(AlterColumn, 'postgresql') + def visit_alter_column(element, compiler, **kw): + return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, + element.column.name) + +The second ``visit_alter_column`` will be invoked when any ``postgresql`` +dialect is used. + +.. _compilerext_compiling_subelements: + +Compiling sub-elements of a custom expression construct +======================================================= + +The ``compiler`` argument is the +:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object +can be inspected for any information about the in-progress compilation, +including ``compiler.dialect``, ``compiler.statement`` etc.
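 For example, a compilation rule may branch on attributes of the compiler, such as the dialect in use; a sketch, reusing the hypothetical ``MyColumn`` from the synopsis above as an alternative to its single handler there:: + + @compiles(MyColumn) + def compile_mycolumn(element, compiler, **kw): + # bracket-quote on SQL Server, double-quote elsewhere + if compiler.dialect.name == 'mssql': + return "[%s]" % element.name + else: + return '"%s"' % element.name + +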
The +:class:`~sqlalchemy.sql.compiler.SQLCompiler` and +:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()`` +method which can be used for compilation of embedded attributes:: + + from sqlalchemy.sql.expression import Executable, ClauseElement + + class InsertFromSelect(Executable, ClauseElement): + inherit_cache = False + + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True, **kw), + compiler.process(element.select, **kw) + ) + + insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5)) + print(insert) + +Produces:: + + "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z + FROM mytable WHERE mytable.x > :x_1)" + +.. note:: + + The above ``InsertFromSelect`` construct is only an example, this actual + functionality is already available using the + :meth:`_expression.Insert.from_select` method. + +.. note:: + + The above ``InsertFromSelect`` construct probably wants to have "autocommit" + enabled. See :ref:`enabling_compiled_autocommit` for this step. + +Cross Compiling between SQL and DDL compilers +--------------------------------------------- + +SQL and DDL constructs are each compiled using different base compilers - +``SQLCompiler`` and ``DDLCompiler``. A common need is to access the +compilation rules of SQL expressions from within a DDL expression. The +``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as +below where we generate a CHECK constraint that embeds a SQL expression:: + + @compiles(MyConstraint) + def compile_my_constraint(constraint, ddlcompiler, **kw): + kw['literal_binds'] = True + return "CONSTRAINT %s CHECK (%s)" % ( + constraint.name, + ddlcompiler.sql_compiler.process( + constraint.expression, **kw) + ) + +Above, we add an additional flag to the process step as called by +:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This +indicates that any SQL expression which refers to a :class:`.BindParameter` +object or other "literal" object such as those which refer to strings or +integers should be rendered **in-place**, rather than being referred to as +a bound parameter; when emitting DDL, bound parameters are typically not +supported. + + +.. _enabling_compiled_autocommit: + +Enabling Autocommit on a Construct +================================== + +Recall from the section :ref:`autocommit` that the :class:`_engine.Engine`, +when +asked to execute a construct in the absence of a user-defined transaction, +detects if the given construct represents DML or DDL, that is, a data +modification or data definition statement, which requires (or may require, +in the case of DDL) that the transaction generated by the DBAPI be committed +(recall that DBAPI always has a transaction going on regardless of what +SQLAlchemy does). Checking for this is actually accomplished by checking for +the "autocommit" execution option on the construct. When building a +construct like an INSERT derivation, a new DDL type, or perhaps a stored +procedure that alters data, the "autocommit" option needs to be set in order +for the statement to function with "connectionless" execution +(as described in :ref:`dbengine_implicit`). 
+ +Currently a quick way to do this is to subclass :class:`.Executable`, then +add the "autocommit" flag to the ``_execution_options`` dictionary (note this +is a "frozen" dictionary which supplies a generative ``union()`` method):: + + from sqlalchemy.sql.expression import Executable, ClauseElement + + class MyInsertThing(Executable, ClauseElement): + _execution_options = \ + Executable._execution_options.union({'autocommit': True}) + +More succinctly, if the construct is truly similar to an INSERT, UPDATE, or +DELETE, :class:`.UpdateBase` can be used, which already is a subclass +of :class:`.Executable`, :class:`_expression.ClauseElement` and includes the +``autocommit`` flag:: + + from sqlalchemy.sql.expression import UpdateBase + + class MyInsertThing(UpdateBase): + def __init__(self, ...): + ... + + + + +DDL elements that subclass :class:`.DDLElement` already have the +"autocommit" flag turned on. + + + + +Changing the default compilation of existing constructs +======================================================= + +The compiler extension applies just as well to the existing constructs. When +overriding the compilation of a built in SQL construct, the @compiles +decorator is invoked upon the appropriate class (be sure to use the class, +i.e. ``Insert`` or ``Select``, instead of the creation function such +as ``insert()`` or ``select()``). + +Within the new compilation function, to get at the "original" compilation +routine, use the appropriate visit_XXX method - this +because compiler.process() will call upon the overriding routine and cause +an endless loop. Such as, to add "prefix" to all insert statements:: + + from sqlalchemy.sql.expression import Insert + + @compiles(Insert) + def prefix_inserts(insert, compiler, **kw): + return compiler.visit_insert(insert.prefix_with("some prefix"), **kw) + +The above compiler will prefix all INSERT statements with "some prefix" when +compiled. + +.. _type_compilation_extension: + +Changing Compilation of Types +============================= + +``compiler`` works for types, too, such as below where we implement the +MS-SQL specific 'max' keyword for ``String``/``VARCHAR``:: + + @compiles(String, 'mssql') + @compiles(VARCHAR, 'mssql') + def compile_varchar(element, compiler, **kw): + if element.length == 'max': + return "VARCHAR('max')" + else: + return compiler.visit_VARCHAR(element, **kw) + + foo = Table('foo', metadata, + Column('data', VARCHAR('max')) + ) + +Subclassing Guidelines +====================== + +A big part of using the compiler extension is subclassing SQLAlchemy +expression constructs. To make this easier, the expression and +schema packages feature a set of "bases" intended for common tasks. +A synopsis is as follows: + +* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root + expression class. Any SQL expression can be derived from this base, and is + probably the best choice for longer constructs such as specialized INSERT + statements. + +* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all + "column-like" elements. Anything that you'd place in the "columns" clause of + a SELECT statement (as well as order by and group by) can derive from this - + the object will automatically have Python "comparison" behavior. + + :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a + ``type`` member which is expression's return type. 
This can be established + at the instance level in the constructor, or at the class level if its + generally constant:: + + class timestamp(ColumnElement): + type = TIMESTAMP() + inherit_cache = True + +* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a + ``ColumnElement`` and a "from clause" like object, and represents a SQL + function or stored procedure type of call. Since most databases support + statements along the line of "SELECT FROM <some function>" + ``FunctionElement`` adds in the ability to be used in the FROM clause of a + ``select()`` construct:: + + from sqlalchemy.sql.expression import FunctionElement + + class coalesce(FunctionElement): + name = 'coalesce' + inherit_cache = True + + @compiles(coalesce) + def compile(element, compiler, **kw): + return "coalesce(%s)" % compiler.process(element.clauses, **kw) + + @compiles(coalesce, 'oracle') + def compile(element, compiler, **kw): + if len(element.clauses) > 2: + raise TypeError("coalesce only supports two arguments on Oracle") + return "nvl(%s)" % compiler.process(element.clauses, **kw) + +* :class:`.DDLElement` - The root of all DDL expressions, + like CREATE TABLE, ALTER TABLE, etc. Compilation of :class:`.DDLElement` + subclasses is issued by a :class:`.DDLCompiler` instead of a + :class:`.SQLCompiler`. :class:`.DDLElement` can also be used as an event hook + in conjunction with event hooks like :meth:`.DDLEvents.before_create` and + :meth:`.DDLEvents.after_create`, allowing the construct to be invoked + automatically during CREATE TABLE and DROP TABLE sequences. + + .. seealso:: + + :ref:`metadata_ddl_toplevel` - contains examples of associating + :class:`.DDL` objects (which are themselves :class:`.DDLElement` + instances) with :class:`.DDLEvents` event hooks. + +* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which + should be used with any expression class that represents a "standalone" + SQL statement that can be passed directly to an ``execute()`` method. It + is already implicit within ``DDLElement`` and ``FunctionElement``. + +Most of the above constructs also respond to SQL statement caching. A +subclassed construct will want to define the caching behavior for the object, +which usually means setting the flag ``inherit_cache`` to the value of +``False`` or ``True``. See the next section :ref:`compilerext_caching` +for background. + + +.. _compilerext_caching: + +Enabling Caching Support for Custom Constructs +============================================== + +SQLAlchemy as of version 1.4 includes a +:ref:`SQL compilation caching facility <sql_caching>` which will allow +equivalent SQL constructs to cache their stringified form, along with other +structural information used to fetch results from the statement. + +For reasons discussed at :ref:`caching_caveats`, the implementation of this +caching system takes a conservative approach towards including custom SQL +constructs and/or subclasses within the caching system. This includes that +any user-defined SQL constructs, including all the examples for this +extension, will not participate in caching by default unless they positively +assert that they are able to do so. The :attr:`.HasCacheKey.inherit_cache` +attribute when set to ``True`` at the class level of a specific subclass +will indicate that instances of this class may be safely cached, using the +cache key generation scheme of the immediate superclass. 
This applies +for example to the "synopsis" example indicated previously:: + + class MyColumn(ColumnClause): + inherit_cache = True + + @compiles(MyColumn) + def compile_mycolumn(element, compiler, **kw): + return "[%s]" % element.name + +Above, the ``MyColumn`` class does not include any new state that +affects its SQL compilation; the cache key of ``MyColumn`` instances will +make use of that of the ``ColumnClause`` superclass, meaning it will take +into account the class of the object (``MyColumn``), the string name and +datatype of the object:: + + >>> MyColumn("some_name", String())._generate_cache_key() + CacheKey( + key=('0', <class '__main__.MyColumn'>, + 'name', 'some_name', + 'type', (<class 'sqlalchemy.sql.sqltypes.String'>, + ('length', None), ('collation', None)) + ), bindparams=[]) + +For objects that are likely to be **used liberally as components within many +larger statements**, such as :class:`_schema.Column` subclasses and custom SQL +datatypes, it's important that **caching be enabled as much as possible**, as +this may otherwise negatively affect performance. + +An example of an object that **does** contain state which affects its SQL +compilation is the one illustrated at :ref:`compilerext_compiling_subelements`; +this is an "INSERT FROM SELECT" construct that combines together a +:class:`_schema.Table` as well as a :class:`_sql.Select` construct, each of +which independently affect the SQL string generation of the construct. For +this class, the example illustrates that it simply does not participate in +caching:: + + class InsertFromSelect(Executable, ClauseElement): + inherit_cache = False + + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True, **kw), + compiler.process(element.select, **kw) + ) + +While it is also possible that the above ``InsertFromSelect`` could be made to +produce a cache key that is composed of that of the :class:`_schema.Table` and +:class:`_sql.Select` components together, the API for this is not at the moment +fully public. However, for an "INSERT FROM SELECT" construct, which is only +used by itself for specific operations, caching is not as critical as in the +previous example. + +For objects that are **used in relative isolation and are generally +standalone**, such as custom :term:`DML` constructs like an "INSERT FROM +SELECT", **caching is generally less critical** as the lack of caching for such +a construct will have only localized implications for that specific operation. + + +Further Examples +================ + +"UTC timestamp" function +------------------------- + +A function that works like "CURRENT_TIMESTAMP" except applies the +appropriate conversions so that the time is in UTC time. Timestamps are best +stored in relational databases as UTC, without time zones. UTC so that your +database doesn't think time has gone backwards in the hour when daylight +savings ends, without timezones because timezones are like character +encodings - they're best applied only at the endpoints of an application +(i.e. convert to UTC upon user input, re-apply desired timezone upon display). 
+ +For PostgreSQL and Microsoft SQL Server:: + + from sqlalchemy.sql import expression + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.types import DateTime + + class utcnow(expression.FunctionElement): + type = DateTime() + inherit_cache = True + + @compiles(utcnow, 'postgresql') + def pg_utcnow(element, compiler, **kw): + return "TIMEZONE('utc', CURRENT_TIMESTAMP)" + + @compiles(utcnow, 'mssql') + def ms_utcnow(element, compiler, **kw): + return "GETUTCDATE()" + +Example usage:: + + from sqlalchemy import ( + Table, Column, Integer, String, DateTime, MetaData + ) + metadata = MetaData() + event = Table("event", metadata, + Column("id", Integer, primary_key=True), + Column("description", String(50), nullable=False), + Column("timestamp", DateTime, server_default=utcnow()) + ) + +"GREATEST" function +------------------- + +The "GREATEST" function is given any number of arguments and returns the one +that is of the highest value - its equivalent to Python's ``max`` +function. A SQL standard version versus a CASE based version which only +accommodates two arguments:: + + from sqlalchemy.sql import expression, case + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.types import Numeric + + class greatest(expression.FunctionElement): + type = Numeric() + name = 'greatest' + inherit_cache = True + + @compiles(greatest) + def default_greatest(element, compiler, **kw): + return compiler.visit_function(element) + + @compiles(greatest, 'sqlite') + @compiles(greatest, 'mssql') + @compiles(greatest, 'oracle') + def case_greatest(element, compiler, **kw): + arg1, arg2 = list(element.clauses) + return compiler.process(case([(arg1 > arg2, arg1)], else_=arg2), **kw) + +Example usage:: + + Session.query(Account).\ + filter( + greatest( + Account.checking_balance, + Account.savings_balance) > 10000 + ) + +"false" expression +------------------ + +Render a "false" constant expression, rendering as "0" on platforms that +don't have a "false" constant:: + + from sqlalchemy.sql import expression + from sqlalchemy.ext.compiler import compiles + + class sql_false(expression.ColumnElement): + inherit_cache = True + + @compiles(sql_false) + def default_false(element, compiler, **kw): + return "false" + + @compiles(sql_false, 'mssql') + @compiles(sql_false, 'mysql') + @compiles(sql_false, 'oracle') + def int_false(element, compiler, **kw): + return "0" + +Example usage:: + + from sqlalchemy import select, union_all + + exp = union_all( + select(users.c.name, sql_false().label("enrolled")), + select(customers.c.name, customers.c.enrolled) + ) + +""" +from .. import exc +from .. import util +from ..sql import sqltypes + + +def compiles(class_, *specs): + """Register a function as a compiler for a + given :class:`_expression.ClauseElement` type.""" + + def decorate(fn): + # get an existing @compiles handler + existing = class_.__dict__.get("_compiler_dispatcher", None) + + # get the original handler. All ClauseElement classes have one + # of these, but some TypeEngine classes will not. + existing_dispatch = getattr(class_, "_compiler_dispatch", None) + + if not existing: + existing = _dispatcher() + + if existing_dispatch: + + def _wrap_existing_dispatch(element, compiler, **kw): + try: + return existing_dispatch(element, compiler, **kw) + except exc.UnsupportedCompilationError as uce: + util.raise_( + exc.UnsupportedCompilationError( + compiler, + type(element), + message="%s construct has no default " + "compilation handler." 
% type(element), + ), + from_=uce, + ) + + existing.specs["default"] = _wrap_existing_dispatch + + # TODO: why is the lambda needed ? + setattr( + class_, + "_compiler_dispatch", + lambda *arg, **kw: existing(*arg, **kw), + ) + setattr(class_, "_compiler_dispatcher", existing) + + if specs: + for s in specs: + existing.specs[s] = fn + + else: + existing.specs["default"] = fn + return fn + + return decorate + + +def deregister(class_): + """Remove all custom compilers associated with a given + :class:`_expression.ClauseElement` type. + + """ + + if hasattr(class_, "_compiler_dispatcher"): + class_._compiler_dispatch = class_._original_compiler_dispatch + del class_._compiler_dispatcher + + +class _dispatcher(object): + def __init__(self): + self.specs = {} + + def __call__(self, element, compiler, **kw): + # TODO: yes, this could also switch off of DBAPI in use. + fn = self.specs.get(compiler.dialect.name, None) + if not fn: + try: + fn = self.specs["default"] + except KeyError as ke: + util.raise_( + exc.UnsupportedCompilationError( + compiler, + type(element), + message="%s construct has no default " + "compilation handler." % type(element), + ), + replace_context=ke, + ) + + # if compilation includes add_to_result_map, collect add_to_result_map + # arguments from the user-defined callable, which are probably none + # because this is not public API. if it wasn't called, then call it + # ourselves. + arm = kw.get("add_to_result_map", None) + if arm: + arm_collection = [] + kw["add_to_result_map"] = lambda *args: arm_collection.append(args) + + expr = fn(element, compiler, **kw) + + if arm: + if not arm_collection: + arm_collection.append( + (None, None, (element,), sqltypes.NULLTYPE) + ) + for tup in arm_collection: + arm(*tup) + return expr diff --git a/lib/sqlalchemy/ext/declarative/__init__.py b/lib/sqlalchemy/ext/declarative/__init__.py new file mode 100644 index 0000000..6215e35 --- /dev/null +++ b/lib/sqlalchemy/ext/declarative/__init__.py @@ -0,0 +1,64 @@ +# ext/declarative/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .extensions import AbstractConcreteBase +from .extensions import ConcreteBase +from .extensions import DeferredReflection +from .extensions import instrument_declarative +from ... import util +from ...orm.decl_api import as_declarative as _as_declarative +from ...orm.decl_api import declarative_base as _declarative_base +from ...orm.decl_api import DeclarativeMeta +from ...orm.decl_api import declared_attr +from ...orm.decl_api import has_inherited_table as _has_inherited_table +from ...orm.decl_api import synonym_for as _synonym_for + + +@util.moved_20( + "The ``declarative_base()`` function is now available as " + ":func:`sqlalchemy.orm.declarative_base`." +) +def declarative_base(*arg, **kw): + return _declarative_base(*arg, **kw) + + +@util.moved_20( + "The ``as_declarative()`` function is now available as " + ":func:`sqlalchemy.orm.as_declarative`" +) +def as_declarative(*arg, **kw): + return _as_declarative(*arg, **kw) + + +@util.moved_20( + "The ``has_inherited_table()`` function is now available as " + ":func:`sqlalchemy.orm.has_inherited_table`." 
+) +def has_inherited_table(*arg, **kw): + return _has_inherited_table(*arg, **kw) + + +@util.moved_20( + "The ``synonym_for()`` function is now available as " + ":func:`sqlalchemy.orm.synonym_for`" +) +def synonym_for(*arg, **kw): + return _synonym_for(*arg, **kw) + + +__all__ = [ + "declarative_base", + "synonym_for", + "has_inherited_table", + "instrument_declarative", + "declared_attr", + "as_declarative", + "ConcreteBase", + "AbstractConcreteBase", + "DeclarativeMeta", + "DeferredReflection", +] diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py new file mode 100644 index 0000000..7818841 --- /dev/null +++ b/lib/sqlalchemy/ext/declarative/extensions.py @@ -0,0 +1,463 @@ +# ext/declarative/extensions.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +"""Public API functions and helpers for declarative.""" + + +from ... import inspection +from ... import util +from ...orm import exc as orm_exc +from ...orm import registry +from ...orm import relationships +from ...orm.base import _mapper_or_none +from ...orm.clsregistry import _resolver +from ...orm.decl_base import _DeferredMapperConfig +from ...orm.util import polymorphic_union +from ...schema import Table +from ...util import OrderedDict + + +@util.deprecated( + "2.0", + "the instrument_declarative function is deprecated " + "and will be removed in SQLAlchemy 2.0. Please use " + ":meth:`_orm.registry.map_declaratively`", +) +def instrument_declarative(cls, cls_registry, metadata): + """Given a class, configure the class declaratively, + using the given registry, which can be any dictionary, and + MetaData object. + + """ + registry(metadata=metadata, class_registry=cls_registry).map_declaratively( + cls + ) + + +class ConcreteBase(object): + """A helper class for 'concrete' declarative mappings. + + :class:`.ConcreteBase` will use the :func:`.polymorphic_union` + function automatically, against all tables mapped as a subclass + to this class. The function is called via the + ``__declare_last__()`` function, which is essentially + a hook for the :meth:`.after_configured` event. + + :class:`.ConcreteBase` produces a mapped + table for the class itself. Compare to :class:`.AbstractConcreteBase`, + which does not. + + Example:: + + from sqlalchemy.ext.declarative import ConcreteBase + + class Employee(ConcreteBase, Base): + __tablename__ = 'employee' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + __mapper_args__ = { + 'polymorphic_identity':'employee', + 'concrete':True} + + class Manager(Employee): + __tablename__ = 'manager' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + manager_data = Column(String(40)) + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True} + + + The name of the discriminator column used by :func:`.polymorphic_union` + defaults to the name ``type``. To suit the use case of a mapping where an + actual column in a mapped table is already named ``type``, the + discriminator name can be configured by setting the + ``_concrete_discriminator_name`` attribute:: + + class Employee(ConcreteBase, Base): + _concrete_discriminator_name = '_concrete_discriminator' + + ..
versionadded:: 1.3.19 Added the ``_concrete_discriminator_name`` + attribute to :class:`_declarative.ConcreteBase` so that the + virtual discriminator column name can be customized. + + .. versionchanged:: 1.4.2 The ``_concrete_discriminator_name`` attribute + need only be placed on the basemost class to take correct effect for + all subclasses. An explicit error message is now raised if the + mapped column names conflict with the discriminator name, whereas + in the 1.3.x series there would be some warnings and then a non-useful + query would be generated. + + .. seealso:: + + :class:`.AbstractConcreteBase` + + :ref:`concrete_inheritance` + + + """ + + @classmethod + def _create_polymorphic_union(cls, mappers, discriminator_name): + return polymorphic_union( + OrderedDict( + (mp.polymorphic_identity, mp.local_table) for mp in mappers + ), + discriminator_name, + "pjoin", + ) + + @classmethod + def __declare_first__(cls): + m = cls.__mapper__ + if m.with_polymorphic: + return + + discriminator_name = ( + getattr(cls, "_concrete_discriminator_name", None) or "type" + ) + + mappers = list(m.self_and_descendants) + pjoin = cls._create_polymorphic_union(mappers, discriminator_name) + m._set_with_polymorphic(("*", pjoin)) + m._set_polymorphic_on(pjoin.c[discriminator_name]) + + +class AbstractConcreteBase(ConcreteBase): + """A helper class for 'concrete' declarative mappings. + + :class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union` + function automatically, against all tables mapped as a subclass + to this class. The function is called via the + ``__declare_last__()`` function, which is essentially + a hook for the :meth:`.after_configured` event. + + :class:`.AbstractConcreteBase` does produce a mapped class + for the base class, however it is not persisted to any table; it + is instead mapped directly to the "polymorphic" selectable directly + and is only used for selecting. Compare to :class:`.ConcreteBase`, + which does create a persisted table for the base class. + + .. note:: + + The :class:`.AbstractConcreteBase` class does not intend to set up the + mapping for the base class until all the subclasses have been defined, + as it needs to create a mapping against a selectable that will include + all subclass tables. In order to achieve this, it waits for the + **mapper configuration event** to occur, at which point it scans + through all the configured subclasses and sets up a mapping that will + query against all subclasses at once. + + While this event is normally invoked automatically, in the case of + :class:`.AbstractConcreteBase`, it may be necessary to invoke it + explicitly after **all** subclass mappings are defined, if the first + operation is to be a query against this base class. To do so, invoke + :func:`.configure_mappers` once all the desired classes have been + configured:: + + from sqlalchemy.orm import configure_mappers + + configure_mappers() + + .. 
seealso:: + + :func:`_orm.configure_mappers` + + + Example:: + + from sqlalchemy.ext.declarative import AbstractConcreteBase + + class Employee(AbstractConcreteBase, Base): + pass + + class Manager(Employee): + __tablename__ = 'manager' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + manager_data = Column(String(40)) + + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True} + + configure_mappers() + + The abstract base class is handled by declarative in a special way; + at class configuration time, it behaves like a declarative mixin + or an ``__abstract__`` base class. Once classes are configured + and mappings are produced, it then gets mapped itself, but + after all of its descendants. This is a very unique system of mapping + not found in any other SQLAlchemy system. + + Using this approach, we can specify columns and properties + that will take place on mapped subclasses, in the way that + we normally do as in :ref:`declarative_mixins`:: + + class Company(Base): + __tablename__ = 'company' + id = Column(Integer, primary_key=True) + + class Employee(AbstractConcreteBase, Base): + employee_id = Column(Integer, primary_key=True) + + @declared_attr + def company_id(cls): + return Column(ForeignKey('company.id')) + + @declared_attr + def company(cls): + return relationship("Company") + + class Manager(Employee): + __tablename__ = 'manager' + + name = Column(String(50)) + manager_data = Column(String(40)) + + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True} + + configure_mappers() + + When we make use of our mappings however, both ``Manager`` and + ``Employee`` will have an independently usable ``.company`` attribute:: + + session.query(Employee).filter(Employee.company.has(id=5)) + + .. versionchanged:: 1.0.0 - The mechanics of :class:`.AbstractConcreteBase` + have been reworked to support relationships established directly + on the abstract base, without any special configurational steps. + + .. seealso:: + + :class:`.ConcreteBase` + + :ref:`concrete_inheritance` + + """ + + __no_table__ = True + + @classmethod + def __declare_first__(cls): + cls._sa_decl_prepare_nocascade() + + @classmethod + def _sa_decl_prepare_nocascade(cls): + if getattr(cls, "__mapper__", None): + return + + to_map = _DeferredMapperConfig.config_for_cls(cls) + + # can't rely on 'self_and_descendants' here + # since technically an immediate subclass + # might not be mapped, but a subclass + # may be. + mappers = [] + stack = list(cls.__subclasses__()) + while stack: + klass = stack.pop() + stack.extend(klass.__subclasses__()) + mn = _mapper_or_none(klass) + if mn is not None: + mappers.append(mn) + + discriminator_name = ( + getattr(cls, "_concrete_discriminator_name", None) or "type" + ) + pjoin = cls._create_polymorphic_union(mappers, discriminator_name) + + # For columns that were declared on the class, these + # are normally ignored with the "__no_table__" mapping, + # unless they have a different attribute key vs. col name + # and are in the properties argument. + # In that case, ensure we update the properties entry + # to the correct column from the pjoin target table. 
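 + # (hypothetical illustration: an attribute keyed ``id`` for a column + # named ``my_id``, supplied via the ``properties`` dictionary)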
+ declared_cols = set(to_map.declared_columns) + for k, v in list(to_map.properties.items()): + if v in declared_cols: + to_map.properties[k] = pjoin.c[v.key] + + to_map.local_table = pjoin + + m_args = to_map.mapper_args_fn or dict + + def mapper_args(): + args = m_args() + args["polymorphic_on"] = pjoin.c[discriminator_name] + return args + + to_map.mapper_args_fn = mapper_args + + m = to_map.map() + + for scls in cls.__subclasses__(): + sm = _mapper_or_none(scls) + if sm and sm.concrete and cls in scls.__bases__: + sm._set_concrete_base(m) + + @classmethod + def _sa_raise_deferred_config(cls): + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of AbstractConcreteBase and " + "has a mapping pending until all subclasses are defined. " + "Call the sqlalchemy.orm.configure_mappers() function after " + "all subclasses have been defined to " + "complete the mapping of this class." + % orm_exc._safe_cls_name(cls), + ) + + +class DeferredReflection(object): + """A helper class for construction of mappings based on + a deferred reflection step. + + Normally, declarative can be used with reflection by + setting a :class:`_schema.Table` object using autoload_with=engine + as the ``__table__`` attribute on a declarative class. + The caveat is that the :class:`_schema.Table` must be fully + reflected, or at the very least have a primary key column, + at the point at which a normal declarative mapping is + constructed, meaning the :class:`_engine.Engine` must be available + at class declaration time. + + The :class:`.DeferredReflection` mixin moves the construction + of mappers to be at a later point, after a specific + method is called which first reflects all :class:`_schema.Table` + objects created so far. Classes can define it as such:: + + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.declarative import DeferredReflection + Base = declarative_base() + + class MyClass(DeferredReflection, Base): + __tablename__ = 'mytable' + + Above, ``MyClass`` is not yet mapped. After a series of + classes have been defined in the above fashion, all tables + can be reflected and mappings created using + :meth:`.prepare`:: + + engine = create_engine("someengine://...") + DeferredReflection.prepare(engine) + + The :class:`.DeferredReflection` mixin can be applied to individual + classes, used as the base for the declarative base itself, + or used in a custom abstract class. Using an abstract base + allows that only a subset of classes to be prepared for a + particular prepare step, which is necessary for applications + that use more than one engine. For example, if an application + has two engines, you might use two bases, and prepare each + separately, e.g.:: + + class ReflectedOne(DeferredReflection, Base): + __abstract__ = True + + class ReflectedTwo(DeferredReflection, Base): + __abstract__ = True + + class MyClass(ReflectedOne): + __tablename__ = 'mytable' + + class MyOtherClass(ReflectedOne): + __tablename__ = 'myothertable' + + class YetAnotherClass(ReflectedTwo): + __tablename__ = 'yetanothertable' + + # ... etc. + + Above, the class hierarchies for ``ReflectedOne`` and + ``ReflectedTwo`` can be configured separately:: + + ReflectedOne.prepare(engine_one) + ReflectedTwo.prepare(engine_two) + + .. seealso:: + + :ref:`orm_declarative_reflected_deferred_reflection` - in the + :ref:`orm_declarative_table_config_toplevel` section. 
+ + """ + + @classmethod + def prepare(cls, engine): + """Reflect all :class:`_schema.Table` objects for all current + :class:`.DeferredReflection` subclasses""" + + to_map = _DeferredMapperConfig.classes_for_base(cls) + + with inspection.inspect(engine)._inspection_context() as insp: + for thingy in to_map: + cls._sa_decl_prepare(thingy.local_table, insp) + thingy.map() + mapper = thingy.cls.__mapper__ + metadata = mapper.class_.metadata + for rel in mapper._props.values(): + if ( + isinstance(rel, relationships.RelationshipProperty) + and rel.secondary is not None + ): + if isinstance(rel.secondary, Table): + cls._reflect_table(rel.secondary, insp) + elif isinstance(rel.secondary, str): + + _, resolve_arg = _resolver(rel.parent.class_, rel) + + rel.secondary = resolve_arg(rel.secondary) + rel.secondary._resolvers += ( + cls._sa_deferred_table_resolver( + insp, metadata + ), + ) + + # controversy! do we resolve it here? or leave + # it deferred? I think doing it here is necessary + # so the connection does not leak. + rel.secondary = rel.secondary() + + @classmethod + def _sa_deferred_table_resolver(cls, inspector, metadata): + def _resolve(key): + t1 = Table(key, metadata) + cls._reflect_table(t1, inspector) + return t1 + + return _resolve + + @classmethod + def _sa_decl_prepare(cls, local_table, inspector): + # autoload Table, which is already + # present in the metadata. This + # will fill in db-loaded columns + # into the existing Table object. + if local_table is not None: + cls._reflect_table(local_table, inspector) + + @classmethod + def _sa_raise_deferred_config(cls): + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of DeferredReflection. " + "Mappings are not produced until the .prepare() " + "method is called on the class hierarchy." + % orm_exc._safe_cls_name(cls), + ) + + @classmethod + def _reflect_table(cls, table, inspector): + Table( + table.name, + table.metadata, + extend_existing=True, + autoload_replace=False, + autoload_with=inspector, + schema=table.schema, + ) diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py new file mode 100644 index 0000000..bad076e --- /dev/null +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -0,0 +1,256 @@ +# ext/horizontal_shard.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Horizontal sharding support. + +Defines a rudimental 'horizontal sharding' system which allows a Session to +distribute queries and persistence operations across multiple databases. + +For a usage example, see the :ref:`examples_sharding` example included in +the source distribution. + +""" + +from .. import event +from .. import exc +from .. import inspect +from .. import util +from ..orm.query import Query +from ..orm.session import Session + +__all__ = ["ShardedSession", "ShardedQuery"] + + +class ShardedQuery(Query): + def __init__(self, *args, **kwargs): + super(ShardedQuery, self).__init__(*args, **kwargs) + self.id_chooser = self.session.id_chooser + self.query_chooser = self.session.query_chooser + self.execute_chooser = self.session.execute_chooser + self._shard_id = None + + def set_shard(self, shard_id): + """Return a new query, limited to a single shard ID. + + All subsequent operations with the returned query will + be against the single shard regardless of other state. 
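 + + E.g., a usage sketch; ``User`` is an assumed mapped class:: + + q = session.query(User).set_shard("shard_1") + users = q.all() # runs only against the "shard_1" bind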
+ + The shard_id can be passed for a 2.0 style execution to the + bind_arguments dictionary of :meth:`.Session.execute`:: + + results = session.execute( + stmt, + bind_arguments={"shard_id": "my_shard"} + ) + + """ + return self.execution_options(_sa_shard_id=shard_id) + + +class ShardedSession(Session): + def __init__( + self, + shard_chooser, + id_chooser, + execute_chooser=None, + shards=None, + query_cls=ShardedQuery, + **kwargs + ): + """Construct a ShardedSession. + + :param shard_chooser: A callable which, passed a Mapper, a mapped + instance, and possibly a SQL clause, returns a shard ID. This id + may be based off of the attributes present within the object, or on + some round-robin scheme. If the scheme is based on a selection, it + should set whatever state on the instance to mark it in the future as + participating in that shard. + + :param id_chooser: A callable, passed a query and a tuple of identity + values, which should return a list of shard ids where the ID might + reside. The databases will be queried in the order of this listing. + + :param execute_chooser: For a given :class:`.ORMExecuteState`, + returns the list of shard_ids + where the query should be issued. Results from all shards returned + will be combined together into a single listing. + + .. versionchanged:: 1.4 The ``execute_chooser`` parameter + supersedes the ``query_chooser`` parameter. + + :param shards: A dictionary of string shard names + to :class:`~sqlalchemy.engine.Engine` objects. + + """ + query_chooser = kwargs.pop("query_chooser", None) + super(ShardedSession, self).__init__(query_cls=query_cls, **kwargs) + + event.listen( + self, "do_orm_execute", execute_and_instances, retval=True + ) + self.shard_chooser = shard_chooser + self.id_chooser = id_chooser + + if query_chooser: + util.warn_deprecated( + "The ``query_chooser`` parameter is deprecated; " + "please use ``execute_chooser``.", + "1.4", + ) + if execute_chooser: + raise exc.ArgumentError( + "Can't pass query_chooser and execute_chooser " + "at the same time." + ) + + def execute_chooser(orm_context): + return query_chooser(orm_context.statement) + + self.execute_chooser = execute_chooser + else: + self.execute_chooser = execute_chooser + self.query_chooser = query_chooser + self.__binds = {} + if shards is not None: + for k in shards: + self.bind_shard(k, shards[k]) + + def _identity_lookup( + self, + mapper, + primary_key_identity, + identity_token=None, + lazy_loaded_from=None, + **kw + ): + """override the default :meth:`.Session._identity_lookup` method so + that we search for a given non-token primary key identity across all + possible identity tokens (e.g. shard ids). + + .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from + the :class:`_query.Query` object to the :class:`.Session`.
+ + """ + + if identity_token is not None: + return super(ShardedSession, self)._identity_lookup( + mapper, + primary_key_identity, + identity_token=identity_token, + **kw + ) + else: + q = self.query(mapper) + if lazy_loaded_from: + q = q._set_lazyload_from(lazy_loaded_from) + for shard_id in self.id_chooser(q, primary_key_identity): + obj = super(ShardedSession, self)._identity_lookup( + mapper, + primary_key_identity, + identity_token=shard_id, + lazy_loaded_from=lazy_loaded_from, + **kw + ) + if obj is not None: + return obj + + return None + + def _choose_shard_and_assign(self, mapper, instance, **kw): + if instance is not None: + state = inspect(instance) + if state.key: + token = state.key[2] + assert token is not None + return token + elif state.identity_token: + return state.identity_token + + shard_id = self.shard_chooser(mapper, instance, **kw) + if instance is not None: + state.identity_token = shard_id + return shard_id + + def connection_callable( + self, mapper=None, instance=None, shard_id=None, **kwargs + ): + """Provide a :class:`_engine.Connection` to use in the unit of work + flush process. + + """ + + if shard_id is None: + shard_id = self._choose_shard_and_assign(mapper, instance) + + if self.in_transaction(): + return self.get_transaction().connection(mapper, shard_id=shard_id) + else: + return self.get_bind( + mapper, shard_id=shard_id, instance=instance + ).connect(**kwargs) + + def get_bind( + self, mapper=None, shard_id=None, instance=None, clause=None, **kw + ): + if shard_id is None: + shard_id = self._choose_shard_and_assign( + mapper, instance, clause=clause + ) + return self.__binds[shard_id] + + def bind_shard(self, shard_id, bind): + self.__binds[shard_id] = bind + + +def execute_and_instances(orm_context): + if orm_context.is_select: + load_options = active_options = orm_context.load_options + update_options = None + + elif orm_context.is_update or orm_context.is_delete: + load_options = None + update_options = active_options = orm_context.update_delete_options + else: + load_options = update_options = active_options = None + + session = orm_context.session + + def iter_for_shard(shard_id, load_options, update_options): + execution_options = dict(orm_context.local_execution_options) + + bind_arguments = dict(orm_context.bind_arguments) + bind_arguments["shard_id"] = shard_id + + if orm_context.is_select: + load_options += {"_refresh_identity_token": shard_id} + execution_options["_sa_orm_load_options"] = load_options + elif orm_context.is_update or orm_context.is_delete: + update_options += {"_refresh_identity_token": shard_id} + execution_options["_sa_orm_update_options"] = update_options + + return orm_context.invoke_statement( + bind_arguments=bind_arguments, execution_options=execution_options + ) + + if active_options and active_options._refresh_identity_token is not None: + shard_id = active_options._refresh_identity_token + elif "_sa_shard_id" in orm_context.execution_options: + shard_id = orm_context.execution_options["_sa_shard_id"] + elif "shard_id" in orm_context.bind_arguments: + shard_id = orm_context.bind_arguments["shard_id"] + else: + shard_id = None + + if shard_id is not None: + return iter_for_shard(shard_id, load_options, update_options) + else: + partial = [] + for shard_id in session.execute_chooser(orm_context): + result_ = iter_for_shard(shard_id, load_options, update_options) + partial.append(result_) + + return partial[0].merge(*partial[1:]) diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py new file mode 
100644 index 0000000..cc0aca6 --- /dev/null +++ b/lib/sqlalchemy/ext/hybrid.py @@ -0,0 +1,1206 @@ +# ext/hybrid.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Define attributes on ORM-mapped classes that have "hybrid" behavior. + +"hybrid" means the attribute has distinct behaviors defined at the +class level and at the instance level. + +The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of +method decorator, is around 50 lines of code and has almost no +dependencies on the rest of SQLAlchemy. It can, in theory, work with +any descriptor-based expression system. + +Consider a mapping ``Interval``, representing integer ``start`` and ``end`` +values. We can define higher level functions on mapped classes that produce SQL +expressions at the class level, and Python expression evaluation at the +instance level. Below, each function decorated with :class:`.hybrid_method` or +:class:`.hybrid_property` may receive ``self`` as an instance of the class, or +as the class itself:: + + from sqlalchemy import Column, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.orm import Session, aliased + from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method + + Base = declarative_base() + + class Interval(Base): + __tablename__ = 'interval' + + id = Column(Integer, primary_key=True) + start = Column(Integer, nullable=False) + end = Column(Integer, nullable=False) + + def __init__(self, start, end): + self.start = start + self.end = end + + @hybrid_property + def length(self): + return self.end - self.start + + @hybrid_method + def contains(self, point): + return (self.start <= point) & (point <= self.end) + + @hybrid_method + def intersects(self, other): + return self.contains(other.start) | self.contains(other.end) + +Above, the ``length`` property returns the difference between the +``end`` and ``start`` attributes. With an instance of ``Interval``, +this subtraction occurs in Python, using normal Python descriptor +mechanics:: + + >>> i1 = Interval(5, 10) + >>> i1.length + 5 + +When dealing with the ``Interval`` class itself, the :class:`.hybrid_property` +descriptor evaluates the function body given the ``Interval`` class as +the argument, which when evaluated with SQLAlchemy expression mechanics +(here using the :attr:`.QueryableAttribute.expression` accessor) +returns a new SQL expression:: + + >>> print(Interval.length.expression) + interval."end" - interval.start + + >>> print(Session().query(Interval).filter(Interval.length > 10)) + SELECT interval.id AS interval_id, interval.start AS interval_start, + interval."end" AS interval_end + FROM interval + WHERE interval."end" - interval.start > :param_1 + +ORM methods such as :meth:`_query.Query.filter_by` +generally use ``getattr()`` to +locate attributes, so can also be used with hybrid attributes:: + + >>> print(Session().query(Interval).filter_by(length=5)) + SELECT interval.id AS interval_id, interval.start AS interval_start, + interval."end" AS interval_end + FROM interval + WHERE interval."end" - interval.start = :param_1 + +The ``Interval`` class example also illustrates two methods, +``contains()`` and ``intersects()``, decorated with +:class:`.hybrid_method`. This decorator applies the same idea to +methods that :class:`.hybrid_property` applies to attributes. 
The +methods return boolean values, and take advantage of the Python ``|`` +and ``&`` bitwise operators to produce equivalent instance-level and +SQL expression-level boolean behavior:: + + >>> i1.contains(6) + True + >>> i1.contains(15) + False + >>> i1.intersects(Interval(7, 18)) + True + >>> i1.intersects(Interval(25, 29)) + False + + >>> print(Session().query(Interval).filter(Interval.contains(15))) + SELECT interval.id AS interval_id, interval.start AS interval_start, + interval."end" AS interval_end + FROM interval + WHERE interval.start <= :start_1 AND interval."end" > :end_1 + + >>> ia = aliased(Interval) + >>> print(Session().query(Interval, ia).filter(Interval.intersects(ia))) + SELECT interval.id AS interval_id, interval.start AS interval_start, + interval."end" AS interval_end, interval_1.id AS interval_1_id, + interval_1.start AS interval_1_start, interval_1."end" AS interval_1_end + FROM interval, interval AS interval_1 + WHERE interval.start <= interval_1.start + AND interval."end" > interval_1.start + OR interval.start <= interval_1."end" + AND interval."end" > interval_1."end" + +.. _hybrid_distinct_expression: + +Defining Expression Behavior Distinct from Attribute Behavior +-------------------------------------------------------------- + +Our usage of the ``&`` and ``|`` bitwise operators above was +fortunate, considering our functions operated on two boolean values to +return a new one. In many cases, the construction of an in-Python +function and a SQLAlchemy SQL expression have enough differences that +two separate Python expressions should be defined. The +:mod:`~sqlalchemy.ext.hybrid` decorators define the +:meth:`.hybrid_property.expression` modifier for this purpose. As an +example we'll define the radius of the interval, which requires the +usage of the absolute value function:: + + from sqlalchemy import func + + class Interval(object): + # ... + + @hybrid_property + def radius(self): + return abs(self.length) / 2 + + @radius.expression + def radius(cls): + return func.abs(cls.length) / 2 + +Above the Python function ``abs()`` is used for instance-level +operations, the SQL function ``ABS()`` is used via the :data:`.func` +object for class-level expressions:: + + >>> i1.radius + 2 + + >>> print(Session().query(Interval).filter(Interval.radius > 5)) + SELECT interval.id AS interval_id, interval.start AS interval_start, + interval."end" AS interval_end + FROM interval + WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1 + +.. note:: When defining an expression for a hybrid property or method, the + expression method **must** retain the name of the original hybrid, else + the new hybrid with the additional state will be attached to the class + with the non-matching name. To use the example above:: + + class Interval(object): + # ... + + @hybrid_property + def radius(self): + return abs(self.length) / 2 + + # WRONG - the non-matching name will cause this function to be + # ignored + @radius.expression + def radius_expression(cls): + return func.abs(cls.length) / 2 + + This is also true for other mutator methods, such as + :meth:`.hybrid_property.update_expression`. This is the same behavior + as that of the ``@property`` construct that is part of standard Python. + +Defining Setters +---------------- + +Hybrid properties can also define setter methods. If we wanted +``length`` above, when set, to modify the endpoint value:: + + class Interval(object): + # ... 
+ + @hybrid_property + def length(self): + return self.end - self.start + + @length.setter + def length(self, value): + self.end = self.start + value + +The ``length(self, value)`` method is now called upon set:: + + >>> i1 = Interval(5, 10) + >>> i1.length + 5 + >>> i1.length = 12 + >>> i1.end + 17 + +.. _hybrid_bulk_update: + +Allowing Bulk ORM Update +------------------------ + +A hybrid can define a custom "UPDATE" handler for when using the +:meth:`_query.Query.update` method, allowing the hybrid to be used in the +SET clause of the update. + +Normally, when using a hybrid with :meth:`_query.Query.update`, the SQL +expression is used as the column that's the target of the SET. If our +``Interval`` class had a hybrid ``start_point`` that linked to +``Interval.start``, this could be substituted directly:: + + session.query(Interval).update({Interval.start_point: 10}) + +However, when using a composite hybrid like ``Interval.length``, this +hybrid represents more than one column. We can set up a handler that will +accommodate a value passed to :meth:`_query.Query.update` which can affect +this, using the :meth:`.hybrid_property.update_expression` decorator. +A handler that works similarly to our setter would be:: + + class Interval(object): + # ... + + @hybrid_property + def length(self): + return self.end - self.start + + @length.setter + def length(self, value): + self.end = self.start + value + + @length.update_expression + def length(cls, value): + return [ + (cls.end, cls.start + value) + ] + +Above, if we use ``Interval.length`` in an UPDATE expression as:: + + session.query(Interval).update( + {Interval.length: 25}, synchronize_session='fetch') + +We'll get an UPDATE statement along the lines of:: + + UPDATE interval SET end=start + :value + +In some cases, the default "evaluate" strategy can't perform the SET +expression in Python; while the addition operator we're using above +is supported, for more complex SET expressions it will usually be necessary +to use either the "fetch" or False synchronization strategy as illustrated +above. + +.. note:: For ORM bulk updates to work with hybrids, the function name + of the hybrid must match that of how it is accessed. Something + like this wouldn't work:: + + class Interval(object): + # ... + + def _get(self): + return self.end - self.start + + def _set(self, value): + self.end = self.start + value + + def _update_expr(cls, value): + return [ + (cls.end, cls.start + value) + ] + + length = hybrid_property( + fget=_get, fset=_set, update_expr=_update_expr + ) + + The Python descriptor protocol does not provide any reliable way for + a descriptor to know what attribute name it was accessed as, and + the UPDATE scheme currently relies upon being able to access the + attribute from an instance by name in order to perform the instance + synchronization step. + +.. versionadded:: 1.2 added support for bulk updates to hybrid properties. + +Working with Relationships +-------------------------- + +There's no essential difference when creating hybrids that work with +related objects as opposed to column-based data. The need for distinct +expressions tends to be greater. The two variants we'll illustrate +are the "join-dependent" hybrid, and the "correlated subquery" hybrid. 
+ +Join-Dependent Relationship Hybrid +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Consider the following declarative +mapping which relates a ``User`` to a ``SavingsAccount``:: + + from sqlalchemy import Column, Integer, ForeignKey, Numeric, String + from sqlalchemy.orm import relationship + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.hybrid import hybrid_property + + Base = declarative_base() + + class SavingsAccount(Base): + __tablename__ = 'account' + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey('user.id'), nullable=False) + balance = Column(Numeric(15, 5)) + + class User(Base): + __tablename__ = 'user' + id = Column(Integer, primary_key=True) + name = Column(String(100), nullable=False) + + accounts = relationship("SavingsAccount", backref="owner") + + @hybrid_property + def balance(self): + if self.accounts: + return self.accounts[0].balance + else: + return None + + @balance.setter + def balance(self, value): + if not self.accounts: + account = Account(owner=self) + else: + account = self.accounts[0] + account.balance = value + + @balance.expression + def balance(cls): + return SavingsAccount.balance + +The above hybrid property ``balance`` works with the first +``SavingsAccount`` entry in the list of accounts for this user. The +in-Python getter/setter methods can treat ``accounts`` as a Python +list available on ``self``. + +However, at the expression level, it's expected that the ``User`` class will +be used in an appropriate context such that an appropriate join to +``SavingsAccount`` will be present:: + + >>> print(Session().query(User, User.balance). + ... join(User.accounts).filter(User.balance > 5000)) + SELECT "user".id AS user_id, "user".name AS user_name, + account.balance AS account_balance + FROM "user" JOIN account ON "user".id = account.user_id + WHERE account.balance > :balance_1 + +Note however, that while the instance level accessors need to worry +about whether ``self.accounts`` is even present, this issue expresses +itself differently at the SQL expression level, where we basically +would use an outer join:: + + >>> from sqlalchemy import or_ + >>> print (Session().query(User, User.balance).outerjoin(User.accounts). + ... filter(or_(User.balance < 5000, User.balance == None))) + SELECT "user".id AS user_id, "user".name AS user_name, + account.balance AS account_balance + FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id + WHERE account.balance < :balance_1 OR account.balance IS NULL + +Correlated Subquery Relationship Hybrid +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We can, of course, forego being dependent on the enclosing query's usage +of joins in favor of the correlated subquery, which can portably be packed +into a single column expression. A correlated subquery is more portable, but +often performs more poorly at the SQL level. 
+illustrated at :ref:`mapper_column_property_sql_expressions`,
+we can adjust our ``SavingsAccount`` example to aggregate the balances for
+*all* accounts, and use a correlated subquery for the column expression::
+
+    from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
+    from sqlalchemy.orm import relationship
+    from sqlalchemy.ext.declarative import declarative_base
+    from sqlalchemy.ext.hybrid import hybrid_property
+    from sqlalchemy import select, func
+
+    Base = declarative_base()
+
+    class SavingsAccount(Base):
+        __tablename__ = 'account'
+        id = Column(Integer, primary_key=True)
+        user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
+        balance = Column(Numeric(15, 5))
+
+    class User(Base):
+        __tablename__ = 'user'
+        id = Column(Integer, primary_key=True)
+        name = Column(String(100), nullable=False)
+
+        accounts = relationship("SavingsAccount", backref="owner")
+
+        @hybrid_property
+        def balance(self):
+            return sum(acc.balance for acc in self.accounts)
+
+        @balance.expression
+        def balance(cls):
+            return select(func.sum(SavingsAccount.balance)).\
+                where(SavingsAccount.user_id==cls.id).\
+                label('total_balance')
+
+The above recipe will give us the ``balance`` column which renders
+a correlated SELECT::
+
+    >>> print(Session().query(User).filter(User.balance > 400))
+    SELECT "user".id AS user_id, "user".name AS user_name
+    FROM "user"
+    WHERE (SELECT sum(account.balance) AS sum_1
+    FROM account
+    WHERE account.user_id = "user".id) > :param_1
+
+.. _hybrid_custom_comparators:
+
+Building Custom Comparators
+---------------------------
+
+The hybrid property also includes a helper that allows construction of
+custom comparators.  A comparator object allows one to customize the
+behavior of each SQLAlchemy expression operator individually.  They
+are useful when creating custom types that have some highly
+idiosyncratic behavior on the SQL side.
+
+.. note:: The :meth:`.hybrid_property.comparator` decorator introduced
+   in this section **replaces** the use of the
+   :meth:`.hybrid_property.expression` decorator.
+   They cannot be used together.
+
+The example class below allows case-insensitive comparisons on the attribute
+named ``word_insensitive``::
+
+    from sqlalchemy.ext.hybrid import Comparator, hybrid_property
+    from sqlalchemy import func, Column, Integer, String
+    from sqlalchemy.orm import Session
+    from sqlalchemy.ext.declarative import declarative_base
+
+    Base = declarative_base()
+
+    class CaseInsensitiveComparator(Comparator):
+        def __eq__(self, other):
+            return func.lower(self.__clause_element__()) == func.lower(other)
+
+    class SearchWord(Base):
+        __tablename__ = 'searchword'
+        id = Column(Integer, primary_key=True)
+        word = Column(String(255), nullable=False)
+
+        @hybrid_property
+        def word_insensitive(self):
+            return self.word.lower()
+
+        @word_insensitive.comparator
+        def word_insensitive(cls):
+            return CaseInsensitiveComparator(cls.word)
+
+Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()``
+SQL function to both sides::
+
+    >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks"))
+    SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
+    FROM searchword
+    WHERE lower(searchword.word) = lower(:lower_1)
+
+The ``CaseInsensitiveComparator`` above implements part of the
+:class:`.ColumnOperators` interface.  A "coercion" operation like
+lowercasing can be applied to all comparison operations (i.e. ``eq``,
+``lt``, ``gt``, etc.)
+using :meth:`.Operators.operate`::
+
+    class CaseInsensitiveComparator(Comparator):
+        def operate(self, op, other, **kwargs):
+            return op(
+                func.lower(self.__clause_element__()),
+                func.lower(other),
+                **kwargs,
+            )
+
+.. _hybrid_reuse_subclass:
+
+Reusing Hybrid Properties across Subclasses
+-------------------------------------------
+
+A hybrid can be referred to from a superclass, to allow modifying
+methods like :meth:`.hybrid_property.getter`, :meth:`.hybrid_property.setter`
+to be used to redefine those methods on a subclass.  This is similar to
+how the standard Python ``@property`` object works::
+
+    class FirstNameOnly(Base):
+        # ...
+
+        first_name = Column(String)
+
+        @hybrid_property
+        def name(self):
+            return self.first_name
+
+        @name.setter
+        def name(self, value):
+            self.first_name = value
+
+    class FirstNameLastName(FirstNameOnly):
+        # ...
+
+        last_name = Column(String)
+
+        @FirstNameOnly.name.getter
+        def name(self):
+            return self.first_name + ' ' + self.last_name
+
+        @name.setter
+        def name(self, value):
+            self.first_name, self.last_name = value.split(' ', 1)
+
+Above, the ``FirstNameLastName`` class refers to the hybrid from
+``FirstNameOnly.name`` to repurpose its getter and setter for the subclass.
+
+When overriding :meth:`.hybrid_property.expression` and
+:meth:`.hybrid_property.comparator` alone as the first reference to the
+superclass, these names conflict with the same-named accessors on the class-
+level :class:`.QueryableAttribute` object returned at the class level.  To
+override these methods when referring directly to the parent class descriptor,
+add the special qualifier :attr:`.hybrid_property.overrides`, which will de-
+reference the instrumented attribute back to the hybrid object::
+
+    class FirstNameLastName(FirstNameOnly):
+        # ...
+
+        last_name = Column(String)
+
+        @FirstNameOnly.name.overrides.expression
+        def name(cls):
+            return func.concat(cls.first_name, ' ', cls.last_name)
+
+.. versionadded:: 1.2 Added :meth:`.hybrid_property.getter` as well as the
+   ability to redefine accessors per-subclass.
+
+
+Hybrid Value Objects
+--------------------
+
+Note in our previous example, if we were to compare the ``word_insensitive``
+attribute of a ``SearchWord`` instance to a plain Python string, the plain
+Python string would not be coerced to lower case - the
+``CaseInsensitiveComparator`` we built, being returned by
+``@word_insensitive.comparator``, only applies to the SQL side.
+
+A more comprehensive form of the custom comparator is to construct a *Hybrid
+Value Object*. This technique applies the target value or expression to a
+value object which is then returned by the accessor in all cases.  The value
+object allows control of all operations upon the value as well as how compared
+values are treated, both on the SQL expression side as well as the Python
+value side.  Replacing the previous ``CaseInsensitiveComparator`` class with a
+new ``CaseInsensitiveWord`` class::
+
+    class CaseInsensitiveWord(Comparator):
+        "Hybrid value representing a lower case representation of a word."
+
+        def __init__(self, word):
+            if isinstance(word, str):
+                self.word = word.lower()
+            elif isinstance(word, CaseInsensitiveWord):
+                self.word = word.word
+            else:
+                self.word = func.lower(word)
+
+        def operate(self, op, other, **kwargs):
+            if not isinstance(other, CaseInsensitiveWord):
+                other = CaseInsensitiveWord(other)
+            return op(self.word, other.word, **kwargs)
+
+        def __clause_element__(self):
+            return self.word
+
+        def __str__(self):
+            return self.word
+
+        key = 'word'
+        "Label to apply to Query tuple results"
+
+Above, the ``CaseInsensitiveWord`` object represents ``self.word``, which may
+be a SQL function, or may be a Python native string.  By overriding
+``operate()`` and ``__clause_element__()`` to work in terms of ``self.word``,
+all comparison operations will work against the "converted" form of ``word``,
+whether it be SQL side or Python side.  Our ``SearchWord`` class can now
+deliver the ``CaseInsensitiveWord`` object unconditionally from a single
+hybrid call::
+
+    class SearchWord(Base):
+        __tablename__ = 'searchword'
+        id = Column(Integer, primary_key=True)
+        word = Column(String(255), nullable=False)
+
+        @hybrid_property
+        def word_insensitive(self):
+            return CaseInsensitiveWord(self.word)
+
+The ``word_insensitive`` attribute now has case-insensitive comparison
+behavior universally, including SQL expression vs. Python expression (note the
+Python value is converted to lower case on the Python side here)::
+
+    >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks"))
+    SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
+    FROM searchword
+    WHERE lower(searchword.word) = :lower_1
+
+SQL expression versus SQL expression::
+
+    >>> sw1 = aliased(SearchWord)
+    >>> sw2 = aliased(SearchWord)
+    >>> print(Session().query(
+    ...            sw1.word_insensitive,
+    ...            sw2.word_insensitive).\
+    ...            filter(
+    ...                sw1.word_insensitive > sw2.word_insensitive
+    ...            ))
+    SELECT lower(searchword_1.word) AS lower_1,
+    lower(searchword_2.word) AS lower_2
+    FROM searchword AS searchword_1, searchword AS searchword_2
+    WHERE lower(searchword_1.word) > lower(searchword_2.word)
+
+Python only expression::
+
+    >>> ws1 = SearchWord(word="SomeWord")
+    >>> ws1.word_insensitive == "sOmEwOrD"
+    True
+    >>> ws1.word_insensitive == "XOmEwOrX"
+    False
+    >>> print(ws1.word_insensitive)
+    someword
+
+The Hybrid Value pattern is very useful for any kind of value that may have
+multiple representations, such as timestamps, time deltas, units of
+measurement, currencies and encrypted passwords.
+
+.. seealso::
+
+   `Hybrids and Value Agnostic Types
+   <https://techspot.zzzeek.org/2011/10/21/hybrids-and-value-agnostic-types/>`_
+   - on the techspot.zzzeek.org blog
+
+   `Value Agnostic Types, Part II
+   <https://techspot.zzzeek.org/2011/10/29/value-agnostic-types-part-ii/>`_ -
+   on the techspot.zzzeek.org blog
+
+.. _hybrid_transformers:
+
+Building Transformers
+----------------------
+
+A *transformer* is an object which can receive a :class:`_query.Query` object
+and return a new one.  The :class:`_query.Query` object includes a method
+:meth:`.with_transformation` that returns a new :class:`_query.Query`
+transformed by the given function.
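+
+In its simplest form, a transformer is just a function that accepts and
+returns a :class:`_query.Query`.  As a minimal sketch, where ``Widget``
+and ``to_active`` are illustrative names rather than part of the recipe
+that follows::
+
+    def to_active(q):
+        # receive a Query; return a new Query with criteria applied
+        return q.filter(Widget.active == True)
+
+    session.query(Widget).with_transformation(to_active).all()
+
+We can combine this with the :class:`.Comparator` class to produce one type
+of recipe which can both set up the FROM clause of a query as well as assign
+filtering criteria.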
+
+Consider a mapped class ``Node``, which assembles itself into a hierarchical
+tree using the adjacency list pattern::
+
+    from sqlalchemy import Column, Integer, ForeignKey
+    from sqlalchemy.orm import relationship
+    from sqlalchemy.ext.declarative import declarative_base
+    Base = declarative_base()
+
+    class Node(Base):
+        __tablename__ = 'node'
+        id = Column(Integer, primary_key=True)
+        parent_id = Column(Integer, ForeignKey('node.id'))
+        parent = relationship("Node", remote_side=id)
+
+Suppose we wanted to add an accessor ``grandparent``.  This would return the
+``parent`` of ``Node.parent``.  When we have an instance of ``Node``, this is
+simple::
+
+    from sqlalchemy.ext.hybrid import hybrid_property
+
+    class Node(Base):
+        # ...
+
+        @hybrid_property
+        def grandparent(self):
+            return self.parent.parent
+
+For the expression, things are not so clear.  We'd need to construct a
+:class:`_query.Query` where we :meth:`_query.Query.join` twice along
+``Node.parent`` to get to the ``grandparent``.  We can instead return a
+transforming callable that we'll combine with the :class:`.Comparator` class
+to receive any :class:`_query.Query` object, and return a new one that's
+joined to the ``Node.parent`` attribute and filtered based on the given
+criterion::
+
+    from sqlalchemy.ext.hybrid import Comparator
+    from sqlalchemy.orm import aliased
+
+    class GrandparentTransformer(Comparator):
+        def operate(self, op, other, **kwargs):
+            def transform(q):
+                cls = self.__clause_element__()
+                parent_alias = aliased(cls)
+                return q.join(parent_alias, cls.parent).filter(
+                    op(parent_alias.parent, other, **kwargs)
+                )
+
+            return transform
+
+    Base = declarative_base()
+
+    class Node(Base):
+        __tablename__ = 'node'
+        id = Column(Integer, primary_key=True)
+        parent_id = Column(Integer, ForeignKey('node.id'))
+        parent = relationship("Node", remote_side=id)
+
+        @hybrid_property
+        def grandparent(self):
+            return self.parent.parent
+
+        @grandparent.comparator
+        def grandparent(cls):
+            return GrandparentTransformer(cls)
+
+The ``GrandparentTransformer`` overrides the core :meth:`.Operators.operate`
+method at the base of the :class:`.Comparator` hierarchy to return a query-
+transforming callable, which then runs the given comparison operation in a
+particular context.  For example, above, the ``operate`` method is
+called with the :attr:`.Operators.eq` callable as well as the right side of
+the comparison ``Node(id=5)``.  A function ``transform`` is then returned
+which will transform a :class:`_query.Query` first to join to ``Node.parent``,
+then to compare ``parent_alias`` using :attr:`.Operators.eq` against the left
+and right sides, passing into :meth:`_query.Query.filter`:
+
+.. sourcecode:: pycon+sql
+
+    >>> from sqlalchemy.orm import Session
+    >>> session = Session()
+    {sql}>>> session.query(Node).\
+    ...        with_transformation(Node.grandparent==Node(id=5)).\
+    ...        all()
+    SELECT node.id AS node_id, node.parent_id AS node_parent_id
+    FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
+    WHERE :param_1 = node_1.parent_id
+    {stop}
+
+We can modify the pattern to be more verbose but flexible by separating the
+"join" step from the "filter" step.  The tricky part here is ensuring that
+successive instances of ``GrandparentTransformer`` use the same
+:class:`.AliasedClass` object against ``Node``.  Below we use a simple
+memoizing approach that associates a ``GrandparentTransformer`` with each
+class::
+
+    class Node(Base):
+
+        # ...
+
+        @grandparent.comparator
+        def grandparent(cls):
+            # memoize a GrandparentTransformer
+            # per class
+            if '_gp' not in cls.__dict__:
+                cls._gp = GrandparentTransformer(cls)
+            return cls._gp
+
+    class GrandparentTransformer(Comparator):
+
+        def __init__(self, cls):
+            self.parent_alias = aliased(cls)
+
+        @property
+        def join(self):
+            def go(q):
+                return q.join(self.parent_alias, Node.parent)
+            return go
+
+        def operate(self, op, other, **kwargs):
+            return op(self.parent_alias.parent, other, **kwargs)
+
+.. sourcecode:: pycon+sql
+
+    {sql}>>> session.query(Node).\
+    ...        with_transformation(Node.grandparent.join).\
+    ...        filter(Node.grandparent==Node(id=5))
+    SELECT node.id AS node_id, node.parent_id AS node_parent_id
+    FROM node JOIN node AS node_1 ON node_1.id = node.parent_id
+    WHERE :param_1 = node_1.parent_id
+    {stop}
+
+The "transformer" pattern is an experimental pattern that starts to make use
+of some functional programming paradigms.  While it's only recommended for
+advanced and/or patient developers, there's probably a whole lot of amazing
+things it can be used for.
+
+""" # noqa
+from .. import util
+from ..orm import attributes
+from ..orm import interfaces
+
+HYBRID_METHOD = util.symbol("HYBRID_METHOD")
+"""Symbol indicating an :class:`InspectionAttr` that's
+   of type :class:`.hybrid_method`.
+
+   Is assigned to the :attr:`.InspectionAttr.extension_type`
+   attribute.
+
+   .. seealso::
+
+    :attr:`_orm.Mapper.all_orm_attributes`
+
+"""
+
+HYBRID_PROPERTY = util.symbol("HYBRID_PROPERTY")
+"""Symbol indicating an :class:`InspectionAttr` that's
+   of type :class:`.hybrid_property`.
+
+   Is assigned to the :attr:`.InspectionAttr.extension_type`
+   attribute.
+
+   .. seealso::
+
+    :attr:`_orm.Mapper.all_orm_attributes`
+
+"""
+
+
+class hybrid_method(interfaces.InspectionAttrInfo):
+    """A decorator which allows definition of a Python object method with both
+    instance-level and class-level behavior.
+
+    """
+
+    is_attribute = True
+    extension_type = HYBRID_METHOD
+
+    def __init__(self, func, expr=None):
+        """Create a new :class:`.hybrid_method`.
+
+        Usage is typically via decorator::
+
+            from sqlalchemy.ext.hybrid import hybrid_method
+
+            class SomeClass(object):
+                @hybrid_method
+                def value(self, x, y):
+                    return self._value + x + y
+
+                @value.expression
+                def value(cls, x, y):
+                    return func.some_function(cls._value, x, y)
+
+        """
+        self.func = func
+        self.expression(expr or func)
+
+    def __get__(self, instance, owner):
+        if instance is None:
+            return self.expr.__get__(owner, owner.__class__)
+        else:
+            return self.func.__get__(instance, owner)
+
+    def expression(self, expr):
+        """Provide a modifying decorator that defines a
+        SQL-expression producing method."""
+
+        self.expr = expr
+        if not self.expr.__doc__:
+            self.expr.__doc__ = self.func.__doc__
+        return self
+
+
+class hybrid_property(interfaces.InspectionAttrInfo):
+    """A decorator which allows definition of a Python descriptor with both
+    instance-level and class-level behavior.
+
+    """
+
+    is_attribute = True
+    extension_type = HYBRID_PROPERTY
+
+    def __init__(
+        self,
+        fget,
+        fset=None,
+        fdel=None,
+        expr=None,
+        custom_comparator=None,
+        update_expr=None,
+    ):
+        """Create a new :class:`.hybrid_property`.
+
+        Usage is typically via decorator::
+
+            from sqlalchemy.ext.hybrid import hybrid_property
+
+            class SomeClass(object):
+                @hybrid_property
+                def value(self):
+                    return self._value
+
+                @value.setter
+                def value(self, value):
+                    self._value = value
+
+        """
+        self.fget = fget
+        self.fset = fset
+        self.fdel = fdel
+        self.expr = expr
+        self.custom_comparator = custom_comparator
+        self.update_expr = update_expr
+        util.update_wrapper(self, fget)
+
+    def __get__(self, instance, owner):
+        if instance is None:
+            return self._expr_comparator(owner)
+        else:
+            return self.fget(instance)
+
+    def __set__(self, instance, value):
+        if self.fset is None:
+            raise AttributeError("can't set attribute")
+        self.fset(instance, value)
+
+    def __delete__(self, instance):
+        if self.fdel is None:
+            raise AttributeError("can't delete attribute")
+        self.fdel(instance)
+
+    def _copy(self, **kw):
+        defaults = {
+            key: value
+            for key, value in self.__dict__.items()
+            if not key.startswith("_")
+        }
+        defaults.update(**kw)
+        return type(self)(**defaults)
+
+    @property
+    def overrides(self):
+        """Prefix for a method that is overriding an existing attribute.
+
+        The :attr:`.hybrid_property.overrides` accessor just returns
+        this hybrid object, which when called at the class level from
+        a parent class, will de-reference the "instrumented attribute"
+        normally returned at this level, and allow modifying decorators
+        like :meth:`.hybrid_property.expression` and
+        :meth:`.hybrid_property.comparator`
+        to be used without conflicting with the same-named attributes
+        normally present on the :class:`.QueryableAttribute`::
+
+            class SuperClass(object):
+                # ...
+
+                @hybrid_property
+                def foobar(self):
+                    return self._foobar
+
+            class SubClass(SuperClass):
+                # ...
+
+                @SuperClass.foobar.overrides.expression
+                def foobar(cls):
+                    return func.subfoobar(cls._foobar)
+
+        .. versionadded:: 1.2
+
+        .. seealso::
+
+            :ref:`hybrid_reuse_subclass`
+
+        """
+        return self
+
+    def getter(self, fget):
+        """Provide a modifying decorator that defines a getter method.
+
+        .. versionadded:: 1.2
+
+        """
+
+        return self._copy(fget=fget)
+
+    def setter(self, fset):
+        """Provide a modifying decorator that defines a setter method."""
+
+        return self._copy(fset=fset)
+
+    def deleter(self, fdel):
+        """Provide a modifying decorator that defines a deletion method."""
+
+        return self._copy(fdel=fdel)
+
+    def expression(self, expr):
+        """Provide a modifying decorator that defines a SQL-expression
+        producing method.
+
+        When a hybrid is invoked at the class level, the SQL expression given
+        here is wrapped inside of a specialized :class:`.QueryableAttribute`,
+        which is the same kind of object used by the ORM to represent other
+        mapped attributes.  The reason for this is so that other class-level
+        attributes such as docstrings and a reference to the hybrid itself may
+        be maintained within the structure that's returned, without any
+        modifications to the original SQL expression passed in.
+
+        .. note::
+
+            When referring to a hybrid property from an owning class (e.g.
+            ``SomeClass.some_hybrid``), an instance of
+            :class:`.QueryableAttribute` is returned, representing the
+            expression or comparator object as well as this hybrid object.
+            However, that object itself has accessors called ``expression``
+            and ``comparator``; so when attempting to override these
+            decorators on a subclass, it may be necessary to qualify it
+            using the :attr:`.hybrid_property.overrides` modifier first.
+            See that modifier for details.
+
+        .. seealso::
+
+            :ref:`hybrid_distinct_expression`
+
+        """
+
+        return self._copy(expr=expr)
+
+    def comparator(self, comparator):
+        """Provide a modifying decorator that defines a custom
+        comparator producing method.
+
+        The return value of the decorated method should be an instance of
+        :class:`~.hybrid.Comparator`.
+
+        .. note::  The :meth:`.hybrid_property.comparator` decorator
+           **replaces** the use of the :meth:`.hybrid_property.expression`
+           decorator.  They cannot be used together.
+
+        When a hybrid is invoked at the class level, the
+        :class:`~.hybrid.Comparator` object given here is wrapped inside of a
+        specialized :class:`.QueryableAttribute`, which is the same kind of
+        object used by the ORM to represent other mapped attributes.  The
+        reason for this is so that other class-level attributes such as
+        docstrings and a reference to the hybrid itself may be maintained
+        within the structure that's returned, without any modifications to
+        the original comparator object passed in.
+
+        .. note::
+
+            When referring to a hybrid property from an owning class (e.g.
+            ``SomeClass.some_hybrid``), an instance of
+            :class:`.QueryableAttribute` is returned, representing the
+            expression or comparator object as well as this hybrid object.
+            However, that object itself has accessors called ``expression``
+            and ``comparator``; so when attempting to override these
+            decorators on a subclass, it may be necessary to qualify it
+            using the :attr:`.hybrid_property.overrides` modifier first.
+            See that modifier for details.
+
+        """
+        return self._copy(custom_comparator=comparator)
+
+    def update_expression(self, meth):
+        """Provide a modifying decorator that defines an UPDATE tuple
+        producing method.
+
+        The method accepts a single value, which is the value to be
+        rendered into the SET clause of an UPDATE statement.  The method
+        should then process this value into individual column expressions
+        that fit into the ultimate SET clause, and return them as a
+        sequence of 2-tuples.  Each tuple
+        contains a column expression as the key and a value to be rendered.
+
+        E.g.::
+
+            class Person(Base):
+                # ...
+
+                first_name = Column(String)
+                last_name = Column(String)
+
+                @hybrid_property
+                def fullname(self):
+                    return self.first_name + " " + self.last_name
+
+                @fullname.update_expression
+                def fullname(cls, value):
+                    fname, lname = value.split(" ", 1)
+                    return [
+                        (cls.first_name, fname),
+                        (cls.last_name, lname)
+                    ]
+
+        .. versionadded:: 1.2
+
+        """
+        return self._copy(update_expr=meth)
+
+    @util.memoized_property
+    def _expr_comparator(self):
+        if self.custom_comparator is not None:
+            return self._get_comparator(self.custom_comparator)
+        elif self.expr is not None:
+            return self._get_expr(self.expr)
+        else:
+            return self._get_expr(self.fget)
+
+    def _get_expr(self, expr):
+        def _expr(cls):
+            return ExprComparator(cls, expr(cls), self)
+
+        util.update_wrapper(_expr, expr)
+
+        return self._get_comparator(_expr)
+
+    def _get_comparator(self, comparator):
+
+        proxy_attr = attributes.create_proxied_attribute(self)
+
+        def expr_comparator(owner):
+            # because this is the descriptor protocol, we don't really know
+            # what our attribute name is.  so search for it through the
+            # MRO.
+
+            for lookup in owner.__mro__:
+                if self.__name__ in lookup.__dict__:
+                    if lookup.__dict__[self.__name__] is self:
+                        name = self.__name__
+                        break
+            else:
+                name = attributes.NO_KEY
+
+            return proxy_attr(
+                owner,
+                name,
+                self,
+                comparator(owner),
+                doc=comparator.__doc__ or self.__doc__,
+            )
+
+        return expr_comparator
+
+
+class Comparator(interfaces.PropComparator):
+    """A helper class that allows easy construction of custom
+    :class:`~.orm.interfaces.PropComparator`
+    classes for usage with hybrids."""
+
+    property = None
+
+    def __init__(self, expression):
+        self.expression = expression
+
+    def __clause_element__(self):
+        expr = self.expression
+        if hasattr(expr, "__clause_element__"):
+            expr = expr.__clause_element__()
+        return expr
+
+    def adapt_to_entity(self, adapt_to_entity):
+        # interesting....
+        return self
+
+
+class ExprComparator(Comparator):
+    def __init__(self, cls, expression, hybrid):
+        self.cls = cls
+        self.expression = expression
+        self.hybrid = hybrid
+
+    def __getattr__(self, key):
+        return getattr(self.expression, key)
+
+    @property
+    def info(self):
+        return self.hybrid.info
+
+    def _bulk_update_tuples(self, value):
+        if isinstance(self.expression, attributes.QueryableAttribute):
+            return self.expression._bulk_update_tuples(value)
+        elif self.hybrid.update_expr is not None:
+            return self.hybrid.update_expr(self.cls, value)
+        else:
+            return [(self.expression, value)]
+
+    @property
+    def property(self):
+        return self.expression.property
+
+    def operate(self, op, *other, **kwargs):
+        return op(self.expression, *other, **kwargs)
+
+    def reverse_operate(self, op, other, **kwargs):
+        return op(other, self.expression, **kwargs)
diff --git a/lib/sqlalchemy/ext/indexable.py b/lib/sqlalchemy/ext/indexable.py
new file mode 100644
index 0000000..7cbac54
--- /dev/null
+++ b/lib/sqlalchemy/ext/indexable.py
@@ -0,0 +1,352 @@
+# ext/indexable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Define attributes on ORM-mapped classes that have "index" attributes for
+columns with :class:`_types.Indexable` types.
+
+"index" means the attribute is associated with an element of an
+:class:`_types.Indexable` column with the predefined index to access it.
+The :class:`_types.Indexable` types include types such as
+:class:`_types.ARRAY`, :class:`_types.JSON` and
+:class:`_postgresql.HSTORE`.
+
+
+
+The :mod:`~sqlalchemy.ext.indexable` extension provides a
+:class:`_schema.Column`-like interface for any element of an
+:class:`_types.Indexable` typed column.  In simple cases, it can be
+treated as a :class:`_schema.Column`-mapped attribute.
+
+
+.. versionadded:: 1.1
+
+Synopsis
+========
+
+Consider a ``Person`` model with a primary key and a JSON data field.
+While this field may have any number of elements encoded within it,
+we would like to refer to the element called ``name`` individually
+as a dedicated attribute which behaves like a standalone column::
+
+    from sqlalchemy import Column, JSON, Integer
+    from sqlalchemy.ext.declarative import declarative_base
+    from sqlalchemy.ext.indexable import index_property
+
+    Base = declarative_base()
+
+    class Person(Base):
+        __tablename__ = 'person'
+
+        id = Column(Integer, primary_key=True)
+        data = Column(JSON)
+
+        name = index_property('data', 'name')
+
+
+Above, the ``name`` attribute now behaves like a mapped column.  We
+can compose a new ``Person`` and set the value of ``name``::
+
+    >>> person = Person(name='Alchemist')
+
+The value is now accessible::
+
+    >>> person.name
+    'Alchemist'
+
+Behind the scenes, the JSON field was initialized to a new blank dictionary
+and the field was set::
+
+    >>> person.data
+    {'name': 'Alchemist'}
+
+The field is mutable in place::
+
+    >>> person.name = 'Renamed'
+    >>> person.name
+    'Renamed'
+    >>> person.data
+    {'name': 'Renamed'}
+
+When using :class:`.index_property`, the change that we make to the indexable
+structure is also automatically tracked as history; we no longer need
+to use :class:`~.mutable.MutableDict` in order to track this change
+for the unit of work.
+
+Deletions work normally as well::
+
+    >>> del person.name
+    >>> person.data
+    {}
+
+Above, deletion of ``person.name`` deletes the value from the dictionary,
+but not the dictionary itself.
+
+A missing key will produce ``AttributeError``::
+
+    >>> person = Person()
+    >>> person.name
+    Traceback (most recent call last):
+        ...
+    AttributeError: 'name'
+
+Unless you set a default value::
+
+    >>> class Person(Base):
+    ...     __tablename__ = 'person'
+    ...
+    ...     id = Column(Integer, primary_key=True)
+    ...     data = Column(JSON)
+    ...
+    ...     name = index_property('data', 'name', default=None)  # See default
+
+    >>> person = Person()
+    >>> print(person.name)
+    None
+
+
+The attributes are also accessible at the class level.
+Below, we illustrate ``Person.name`` used to generate
+indexed SQL criteria::
+
+    >>> from sqlalchemy.orm import Session
+    >>> session = Session()
+    >>> query = session.query(Person).filter(Person.name == 'Alchemist')
+
+The above query is equivalent to::
+
+    >>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist')
+
+Multiple :class:`.index_property` objects can be chained to produce
+multiple levels of indexing::
+
+    from sqlalchemy import Column, JSON, Integer
+    from sqlalchemy.ext.declarative import declarative_base
+    from sqlalchemy.ext.indexable import index_property
+
+    Base = declarative_base()
+
+    class Person(Base):
+        __tablename__ = 'person'
+
+        id = Column(Integer, primary_key=True)
+        data = Column(JSON)
+
+        birthday = index_property('data', 'birthday')
+        year = index_property('birthday', 'year')
+        month = index_property('birthday', 'month')
+        day = index_property('birthday', 'day')
+
+Above, a query such as::
+
+    q = session.query(Person).filter(Person.year == '1980')
+
+On a PostgreSQL backend, the above query will render as::
+
+    SELECT person.id, person.data
+    FROM person
+    WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s
+
+Default Values
+==============
+
+:class:`.index_property` includes special behaviors for when the indexed
+data structure does not exist, and a set operation is called:
+
+* For an :class:`.index_property` that is given an integer index value,
+  the default data structure will be a Python list of ``None`` values,
+  at least as long as the index value; the value is then set at its
+  place in the list.  This means for an index value of zero, the list
+  will be initialized to ``[None]`` before setting the given value,
+  and for an index value of five, the list will be initialized to
+  ``[None, None, None, None, None, None]`` (six elements) before setting
+  the element at index five to the given value.  Note that an existing
+  list is **not** extended in place to receive a value.
+
+* for an :class:`.index_property` that is given any other kind of index
+  value (e.g. strings usually), a Python dictionary is used as the
+  default data structure.
+
+* The default data structure can be set to any Python callable using the
+  :paramref:`.index_property.datatype` parameter, overriding the previous
+  rules.
+
+
+Subclassing
+===========
+
+:class:`.index_property` can be subclassed, in particular for the common
+use case of providing coercion of values or SQL expressions as they are
+accessed.  Below is a common recipe for use with a PostgreSQL JSON type,
+where we want to also include automatic casting plus conversion to text
+via ``astext``::
+
+    class pg_json_property(index_property):
+        def __init__(self, attr_name, index, cast_type):
+            super(pg_json_property, self).__init__(attr_name, index)
+            self.cast_type = cast_type
+
+        def expr(self, model):
+            expr = super(pg_json_property, self).expr(model)
+            return expr.astext.cast(self.cast_type)
+
+The above subclass can be used with the PostgreSQL-specific
+version of :class:`_postgresql.JSON`::
+
+    from sqlalchemy import Column, Integer
+    from sqlalchemy.ext.declarative import declarative_base
+    from sqlalchemy.dialects.postgresql import JSON
+
+    Base = declarative_base()
+
+    class Person(Base):
+        __tablename__ = 'person'
+
+        id = Column(Integer, primary_key=True)
+        data = Column(JSON)
+
+        age = pg_json_property('data', 'age', Integer)
+
+The ``age`` attribute at the instance level works as before; however
+when rendering SQL, PostgreSQL's ``->>`` operator will be used
+for indexed access, instead of the usual index operator of ``->``::
+
+    >>> query = session.query(Person).filter(Person.age < 20)
+
+The above query will render::
+
+    SELECT person.id, person.data
+    FROM person
+    WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s
+
+""" # noqa
+from __future__ import absolute_import
+
+from .. import inspect
+from .. import util
+from ..ext.hybrid import hybrid_property
+from ..orm.attributes import flag_modified
+
+
+__all__ = ["index_property"]
+
+
+class index_property(hybrid_property):  # noqa
+    """A property generator.  The generated property describes an object
+    attribute that corresponds to an :class:`_types.Indexable`
+    column.
+
+    .. versionadded:: 1.1
+
+    .. seealso::
+
+        :mod:`sqlalchemy.ext.indexable`
+
+    """
+
+    _NO_DEFAULT_ARGUMENT = object()
+
+    def __init__(
+        self,
+        attr_name,
+        index,
+        default=_NO_DEFAULT_ARGUMENT,
+        datatype=None,
+        mutable=True,
+        onebased=True,
+    ):
+        """Create a new :class:`.index_property`.
+
+        :param attr_name:
+            An attribute name of an `Indexable` typed column, or other
+            attribute that returns an indexable structure.
+        :param index:
+            The index to be used for getting and setting this value.  This
+            should be the Python-side index value for integers.
+        :param default:
+            A value which will be returned instead of `AttributeError`
+            when there is no value at the given index.
+        :param datatype: default datatype to use when the field is empty.
+            By default, this is derived from the type of index used; a
+            Python list for an integer index, or a Python dictionary for
+            any other style of index.  For a list, the list will be
+            initialized to a list of None values that is at least
+            ``index`` elements long.
+        :param mutable: if False, writes and deletes to the attribute will
+            be disallowed.
+        :param onebased: assume the SQL representation of this value is
+            one-based; that is, the first index in SQL is 1, not zero.
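+
+        E.g., given a Python-side zero-based index against a one-based SQL
+        type such as PostgreSQL ``ARRAY`` (an illustrative sketch; ``Base``
+        is assumed to be a declarative base)::
+
+            from sqlalchemy import Column, Integer
+            from sqlalchemy.dialects.postgresql import ARRAY
+            from sqlalchemy.ext.indexable import index_property
+
+            class Person(Base):
+                __tablename__ = 'person'
+
+                id = Column(Integer, primary_key=True)
+                data = Column(ARRAY(Integer))
+
+                # Python index 0; with the default onebased=True, the
+                # rendered SQL addresses the one-based element data[1]
+                first = index_property('data', 0)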
+ """ + + if mutable: + super(index_property, self).__init__( + self.fget, self.fset, self.fdel, self.expr + ) + else: + super(index_property, self).__init__( + self.fget, None, None, self.expr + ) + self.attr_name = attr_name + self.index = index + self.default = default + is_numeric = isinstance(index, int) + onebased = is_numeric and onebased + + if datatype is not None: + self.datatype = datatype + else: + if is_numeric: + self.datatype = lambda: [None for x in range(index + 1)] + else: + self.datatype = dict + self.onebased = onebased + + def _fget_default(self, err=None): + if self.default == self._NO_DEFAULT_ARGUMENT: + util.raise_(AttributeError(self.attr_name), replace_context=err) + else: + return self.default + + def fget(self, instance): + attr_name = self.attr_name + column_value = getattr(instance, attr_name) + if column_value is None: + return self._fget_default() + try: + value = column_value[self.index] + except (KeyError, IndexError) as err: + return self._fget_default(err) + else: + return value + + def fset(self, instance, value): + attr_name = self.attr_name + column_value = getattr(instance, attr_name, None) + if column_value is None: + column_value = self.datatype() + setattr(instance, attr_name, column_value) + column_value[self.index] = value + setattr(instance, attr_name, column_value) + if attr_name in inspect(instance).mapper.attrs: + flag_modified(instance, attr_name) + + def fdel(self, instance): + attr_name = self.attr_name + column_value = getattr(instance, attr_name) + if column_value is None: + raise AttributeError(self.attr_name) + try: + del column_value[self.index] + except KeyError as err: + util.raise_(AttributeError(self.attr_name), replace_context=err) + else: + setattr(instance, attr_name, column_value) + flag_modified(instance, attr_name) + + def expr(self, model): + column = getattr(model, self.attr_name) + index = self.index + if self.onebased: + index += 1 + return column[index] diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py new file mode 100644 index 0000000..54f3e64 --- /dev/null +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -0,0 +1,416 @@ +"""Extensible class instrumentation. + +The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate +systems of class instrumentation within the ORM. Class instrumentation +refers to how the ORM places attributes on the class which maintain +data and track changes to that data, as well as event hooks installed +on the class. + +.. note:: + The extension package is provided for the benefit of integration + with other object management packages, which already perform + their own instrumentation. It is not intended for general use. + +For examples of how the instrumentation extension is used, +see the example :ref:`examples_instrumentation`. + +""" +import weakref + +from .. import util +from ..orm import attributes +from ..orm import base as orm_base +from ..orm import collections +from ..orm import exc as orm_exc +from ..orm import instrumentation as orm_instrumentation +from ..orm.instrumentation import _default_dict_getter +from ..orm.instrumentation import _default_manager_getter +from ..orm.instrumentation import _default_state_getter +from ..orm.instrumentation import ClassManager +from ..orm.instrumentation import InstrumentationFactory + + +INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__" +"""Attribute, elects custom instrumentation when present on a mapped class. 
+ +Allows a class to specify a slightly or wildly different technique for +tracking changes made to mapped attributes and collections. + +Only one instrumentation implementation is allowed in a given object +inheritance hierarchy. + +The value of this attribute must be a callable and will be passed a class +object. The callable must return one of: + + - An instance of an :class:`.InstrumentationManager` or subclass + - An object implementing all or some of InstrumentationManager (TODO) + - A dictionary of callables, implementing all or some of the above (TODO) + - An instance of a :class:`.ClassManager` or subclass + +This attribute is consulted by SQLAlchemy instrumentation +resolution, once the :mod:`sqlalchemy.ext.instrumentation` module +has been imported. If custom finders are installed in the global +instrumentation_finders list, they may or may not choose to honor this +attribute. + +""" + + +def find_native_user_instrumentation_hook(cls): + """Find user-specified instrumentation management for a class.""" + return getattr(cls, INSTRUMENTATION_MANAGER, None) + + +instrumentation_finders = [find_native_user_instrumentation_hook] +"""An extensible sequence of callables which return instrumentation +implementations + +When a class is registered, each callable will be passed a class object. +If None is returned, the +next finder in the sequence is consulted. Otherwise the return must be an +instrumentation factory that follows the same guidelines as +sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER. + +By default, the only finder is find_native_user_instrumentation_hook, which +searches for INSTRUMENTATION_MANAGER. If all finders return None, standard +ClassManager instrumentation is used. + +""" + + +class ExtendedInstrumentationRegistry(InstrumentationFactory): + """Extends :class:`.InstrumentationFactory` with additional + bookkeeping, to accommodate multiple types of + class managers. + + """ + + _manager_finders = weakref.WeakKeyDictionary() + _state_finders = weakref.WeakKeyDictionary() + _dict_finders = weakref.WeakKeyDictionary() + _extended = False + + def _locate_extended_factory(self, class_): + for finder in instrumentation_finders: + factory = finder(class_) + if factory is not None: + manager = self._extended_class_manager(class_, factory) + return manager, factory + else: + return None, None + + def _check_conflicts(self, class_, factory): + existing_factories = self._collect_management_factories_for( + class_ + ).difference([factory]) + if existing_factories: + raise TypeError( + "multiple instrumentation implementations specified " + "in %s inheritance hierarchy: %r" + % (class_.__name__, list(existing_factories)) + ) + + def _extended_class_manager(self, class_, factory): + manager = factory(class_) + if not isinstance(manager, ClassManager): + manager = _ClassInstrumentationAdapter(class_, manager) + + if factory != ClassManager and not self._extended: + # somebody invoked a custom ClassManager. + # reinstall global "getter" functions with the more + # expensive ones. + self._extended = True + _install_instrumented_lookups() + + self._manager_finders[class_] = manager.manager_getter() + self._state_finders[class_] = manager.state_getter() + self._dict_finders[class_] = manager.dict_getter() + return manager + + def _collect_management_factories_for(self, cls): + """Return a collection of factories in play or specified for a + hierarchy. + + Traverses the entire inheritance graph of a cls and returns a + collection of instrumentation factories for those classes. 
Factories + are extracted from active ClassManagers, if available, otherwise + instrumentation_finders is consulted. + + """ + hierarchy = util.class_hierarchy(cls) + factories = set() + for member in hierarchy: + manager = self.manager_of_class(member) + if manager is not None: + factories.add(manager.factory) + else: + for finder in instrumentation_finders: + factory = finder(member) + if factory is not None: + break + else: + factory = None + factories.add(factory) + factories.discard(None) + return factories + + def unregister(self, class_): + super(ExtendedInstrumentationRegistry, self).unregister(class_) + if class_ in self._manager_finders: + del self._manager_finders[class_] + del self._state_finders[class_] + del self._dict_finders[class_] + + def manager_of_class(self, cls): + if cls is None: + return None + try: + finder = self._manager_finders.get(cls, _default_manager_getter) + except TypeError: + # due to weakref lookup on invalid object + return None + else: + return finder(cls) + + def state_of(self, instance): + if instance is None: + raise AttributeError("None has no persistent state.") + return self._state_finders.get( + instance.__class__, _default_state_getter + )(instance) + + def dict_of(self, instance): + if instance is None: + raise AttributeError("None has no persistent state.") + return self._dict_finders.get( + instance.__class__, _default_dict_getter + )(instance) + + +orm_instrumentation._instrumentation_factory = ( + _instrumentation_factory +) = ExtendedInstrumentationRegistry() +orm_instrumentation.instrumentation_finders = instrumentation_finders + + +class InstrumentationManager(object): + """User-defined class instrumentation extension. + + :class:`.InstrumentationManager` can be subclassed in order + to change + how class instrumentation proceeds. This class exists for + the purposes of integration with other object management + frameworks which would like to entirely modify the + instrumentation methodology of the ORM, and is not intended + for regular usage. For interception of class instrumentation + events, see :class:`.InstrumentationEvents`. + + The API for this class should be considered as semi-stable, + and may change slightly with new releases. + + """ + + # r4361 added a mandatory (cls) constructor to this interface. + # given that, perhaps class_ should be dropped from all of these + # signatures. 
+ + def __init__(self, class_): + pass + + def manage(self, class_, manager): + setattr(class_, "_default_class_manager", manager) + + def unregister(self, class_, manager): + delattr(class_, "_default_class_manager") + + def manager_getter(self, class_): + def get(cls): + return cls._default_class_manager + + return get + + def instrument_attribute(self, class_, key, inst): + pass + + def post_configure_attribute(self, class_, key, inst): + pass + + def install_descriptor(self, class_, key, inst): + setattr(class_, key, inst) + + def uninstall_descriptor(self, class_, key): + delattr(class_, key) + + def install_member(self, class_, key, implementation): + setattr(class_, key, implementation) + + def uninstall_member(self, class_, key): + delattr(class_, key) + + def instrument_collection_class(self, class_, key, collection_class): + return collections.prepare_instrumentation(collection_class) + + def get_instance_dict(self, class_, instance): + return instance.__dict__ + + def initialize_instance_dict(self, class_, instance): + pass + + def install_state(self, class_, instance, state): + setattr(instance, "_default_state", state) + + def remove_state(self, class_, instance): + delattr(instance, "_default_state") + + def state_getter(self, class_): + return lambda instance: getattr(instance, "_default_state") + + def dict_getter(self, class_): + return lambda inst: self.get_instance_dict(class_, inst) + + +class _ClassInstrumentationAdapter(ClassManager): + """Adapts a user-defined InstrumentationManager to a ClassManager.""" + + def __init__(self, class_, override): + self._adapted = override + self._get_state = self._adapted.state_getter(class_) + self._get_dict = self._adapted.dict_getter(class_) + + ClassManager.__init__(self, class_) + + def manage(self): + self._adapted.manage(self.class_, self) + + def unregister(self): + self._adapted.unregister(self.class_, self) + + def manager_getter(self): + return self._adapted.manager_getter(self.class_) + + def instrument_attribute(self, key, inst, propagated=False): + ClassManager.instrument_attribute(self, key, inst, propagated) + if not propagated: + self._adapted.instrument_attribute(self.class_, key, inst) + + def post_configure_attribute(self, key): + super(_ClassInstrumentationAdapter, self).post_configure_attribute(key) + self._adapted.post_configure_attribute(self.class_, key, self[key]) + + def install_descriptor(self, key, inst): + self._adapted.install_descriptor(self.class_, key, inst) + + def uninstall_descriptor(self, key): + self._adapted.uninstall_descriptor(self.class_, key) + + def install_member(self, key, implementation): + self._adapted.install_member(self.class_, key, implementation) + + def uninstall_member(self, key): + self._adapted.uninstall_member(self.class_, key) + + def instrument_collection_class(self, key, collection_class): + return self._adapted.instrument_collection_class( + self.class_, key, collection_class + ) + + def initialize_collection(self, key, state, factory): + delegate = getattr(self._adapted, "initialize_collection", None) + if delegate: + return delegate(key, state, factory) + else: + return ClassManager.initialize_collection( + self, key, state, factory + ) + + def new_instance(self, state=None): + instance = self.class_.__new__(self.class_) + self.setup_instance(instance, state) + return instance + + def _new_state_if_none(self, instance): + """Install a default InstanceState if none is present. + + A private convenience method used by the __init__ decorator. 
+ """ + if self.has_state(instance): + return False + else: + return self.setup_instance(instance) + + def setup_instance(self, instance, state=None): + self._adapted.initialize_instance_dict(self.class_, instance) + + if state is None: + state = self._state_constructor(instance, self) + + # the given instance is assumed to have no state + self._adapted.install_state(self.class_, instance, state) + return state + + def teardown_instance(self, instance): + self._adapted.remove_state(self.class_, instance) + + def has_state(self, instance): + try: + self._get_state(instance) + except orm_exc.NO_STATE: + return False + else: + return True + + def state_getter(self): + return self._get_state + + def dict_getter(self): + return self._get_dict + + +def _install_instrumented_lookups(): + """Replace global class/object management functions + with ExtendedInstrumentationRegistry implementations, which + allow multiple types of class managers to be present, + at the cost of performance. + + This function is called only by ExtendedInstrumentationRegistry + and unit tests specific to this behavior. + + The _reinstall_default_lookups() function can be called + after this one to re-establish the default functions. + + """ + _install_lookups( + dict( + instance_state=_instrumentation_factory.state_of, + instance_dict=_instrumentation_factory.dict_of, + manager_of_class=_instrumentation_factory.manager_of_class, + ) + ) + + +def _reinstall_default_lookups(): + """Restore simplified lookups.""" + _install_lookups( + dict( + instance_state=_default_state_getter, + instance_dict=_default_dict_getter, + manager_of_class=_default_manager_getter, + ) + ) + _instrumentation_factory._extended = False + + +def _install_lookups(lookups): + global instance_state, instance_dict, manager_of_class + instance_state = lookups["instance_state"] + instance_dict = lookups["instance_dict"] + manager_of_class = lookups["manager_of_class"] + orm_base.instance_state = ( + attributes.instance_state + ) = orm_instrumentation.instance_state = instance_state + orm_base.instance_dict = ( + attributes.instance_dict + ) = orm_instrumentation.instance_dict = instance_dict + orm_base.manager_of_class = ( + attributes.manager_of_class + ) = orm_instrumentation.manager_of_class = manager_of_class diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py new file mode 100644 index 0000000..cbec06a --- /dev/null +++ b/lib/sqlalchemy/ext/mutable.py @@ -0,0 +1,958 @@ +# ext/mutable.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Provide support for tracking of in-place changes to scalar values, +which are propagated into ORM change events on owning parent objects. + +.. _mutable_scalars: + +Establishing Mutability on Scalar Column Values +=============================================== + +A typical example of a "mutable" structure is a Python dictionary. +Following the example introduced in :ref:`types_toplevel`, we +begin with a custom type that marshals Python dictionaries into +JSON strings before being persisted:: + + from sqlalchemy.types import TypeDecorator, VARCHAR + import json + + class JSONEncodedDict(TypeDecorator): + "Represents an immutable structure as a json-encoded string." 
+
+        impl = VARCHAR
+
+        def process_bind_param(self, value, dialect):
+            if value is not None:
+                value = json.dumps(value)
+            return value
+
+        def process_result_value(self, value, dialect):
+            if value is not None:
+                value = json.loads(value)
+            return value
+
+The usage of ``json`` is only for the purposes of example.  The
+:mod:`sqlalchemy.ext.mutable` extension can be used
+with any type whose target Python type may be mutable, including
+:class:`.PickleType`, :class:`_postgresql.ARRAY`, etc.
+
+When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself
+tracks all parents which reference it.  Below, we illustrate a simple
+version of the :class:`.MutableDict` dictionary object, which applies
+the :class:`.Mutable` mixin to a plain Python dictionary::
+
+    from sqlalchemy.ext.mutable import Mutable
+
+    class MutableDict(Mutable, dict):
+        @classmethod
+        def coerce(cls, key, value):
+            "Convert plain dictionaries to MutableDict."
+
+            if not isinstance(value, MutableDict):
+                if isinstance(value, dict):
+                    return MutableDict(value)
+
+                # this call will raise ValueError
+                return Mutable.coerce(key, value)
+            else:
+                return value
+
+        def __setitem__(self, key, value):
+            "Detect dictionary set events and emit change events."
+
+            dict.__setitem__(self, key, value)
+            self.changed()
+
+        def __delitem__(self, key):
+            "Detect dictionary del events and emit change events."
+
+            dict.__delitem__(self, key)
+            self.changed()
+
+The above dictionary class takes the approach of subclassing the Python
+built-in ``dict`` to produce a dict
+subclass which routes all mutation events through ``__setitem__``.  There are
+variants on this approach, such as subclassing ``UserDict.UserDict`` or
+``collections.MutableMapping``; the part that's important to this example is
+that the :meth:`.Mutable.changed` method is called whenever an in-place
+change to the datastructure takes place.
+
+We also redefine the :meth:`.Mutable.coerce` method which will be used to
+convert any values that are not instances of ``MutableDict``, such
+as the plain dictionaries returned by the ``json`` module, into the
+appropriate type.  Defining this method is optional; we could just as well
+have created our ``JSONEncodedDict`` such that it always returns an instance
+of ``MutableDict``, and additionally ensured that all calling code
+uses ``MutableDict`` explicitly.  When :meth:`.Mutable.coerce` is not
+overridden, any values applied to a parent object which are not instances
+of the mutable type will raise a ``ValueError``.
+
+Our new ``MutableDict`` type offers a class method
+:meth:`~.Mutable.as_mutable` which we can use within column metadata
+to associate with types.  This method grabs the given type object or
+class and associates a listener that will detect all future mappings
+of this type, applying event listening instrumentation to the mapped
+attribute.  For example, with classical table metadata::
+
+    from sqlalchemy import Table, Column, Integer
+
+    my_data = Table('my_data', metadata,
+        Column('id', Integer, primary_key=True),
+        Column('data', MutableDict.as_mutable(JSONEncodedDict))
+    )
+
+Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict``
+(if the type object was not an instance already), which will intercept any
+attributes which are mapped against this type.  Below we establish a simple
+mapping against the ``my_data`` table::
+
+    from sqlalchemy.orm import mapper
+
+    class MyDataClass(object):
+        pass
+
+    # associates mutation listeners with MyDataClass.data
+    mapper(MyDataClass, my_data)
+
+The ``MyDataClass.data`` member will now be notified of in place changes
+to its value.
+
+There's no difference in usage when using declarative::
+
+    from sqlalchemy.ext.declarative import declarative_base
+
+    Base = declarative_base()
+
+    class MyDataClass(Base):
+        __tablename__ = 'my_data'
+        id = Column(Integer, primary_key=True)
+        data = Column(MutableDict.as_mutable(JSONEncodedDict))
+
+Any in-place changes to the ``MyDataClass.data`` member
+will flag the attribute as "dirty" on the parent object::
+
+    >>> from sqlalchemy.orm import Session
+
+    >>> sess = Session()
+    >>> m1 = MyDataClass(data={'value1':'foo'})
+    >>> sess.add(m1)
+    >>> sess.commit()
+
+    >>> m1.data['value1'] = 'bar'
+    >>> m1 in sess.dirty
+    True
+
+The ``MutableDict`` can be associated with all future instances
+of ``JSONEncodedDict`` in one step, using
+:meth:`~.Mutable.associate_with`.  This is similar to
+:meth:`~.Mutable.as_mutable` except it will intercept all occurrences
+of ``MutableDict`` in all mappings unconditionally, without
+the need to declare it individually::
+
+    MutableDict.associate_with(JSONEncodedDict)
+
+    class MyDataClass(Base):
+        __tablename__ = 'my_data'
+        id = Column(Integer, primary_key=True)
+        data = Column(JSONEncodedDict)
+
+
+Supporting Pickling
+--------------------
+
+The :mod:`sqlalchemy.ext.mutable` extension relies upon the
+placement of a ``weakref.WeakKeyDictionary`` upon the value object, which
+stores a mapping of parent mapped objects keyed to the attribute name under
+which they are associated with this value.  ``WeakKeyDictionary`` objects are
+not picklable, due to the fact that they contain weakrefs and function
+callbacks.  In our case, this is a good thing, since if this dictionary were
+picklable, it could lead to an excessively large pickle size for our value
+objects that are pickled by themselves outside of the context of the parent.
+The developer responsibility here is only to provide a ``__getstate__`` method
+that excludes the :meth:`~MutableBase._parents` collection from the pickle
+stream::
+
+    class MyMutableType(Mutable):
+        def __getstate__(self):
+            d = self.__dict__.copy()
+            d.pop('_parents', None)
+            return d
+
+With our dictionary example, we need to return the contents of the dict itself
+(and also restore them on __setstate__)::
+
+    class MutableDict(Mutable, dict):
+        # ....
+
+        def __getstate__(self):
+            return dict(self)
+
+        def __setstate__(self, state):
+            self.update(state)
+
+In the case that our mutable value object is pickled as it is attached to one
+or more parent objects that are also part of the pickle, the :class:`.Mutable`
+mixin will re-establish the :attr:`.Mutable._parents` collection on each value
+object as the owning parents themselves are unpickled.
+
+Receiving Events
+----------------
+
+The :meth:`.AttributeEvents.modified` event handler may be used to receive
+an event when a mutable scalar emits a change event.  This event handler
+is called when the :func:`.attributes.flag_modified` function is called
+from within the mutable extension::
+
+    from sqlalchemy.ext.declarative import declarative_base
+    from sqlalchemy import event
+
+    Base = declarative_base()
+
+    class MyDataClass(Base):
+        __tablename__ = 'my_data'
+        id = Column(Integer, primary_key=True)
+        data = Column(MutableDict.as_mutable(JSONEncodedDict))
+
+    @event.listens_for(MyDataClass.data, "modified")
+    def modified_json(instance, initiator):
+        print("json value modified:", instance.data)
+
+.. _mutable_composites:
+
+Establishing Mutability on Composites
+=====================================
+
+Composites are a special ORM feature which allow a single scalar attribute to
+be assigned an object value which represents information "composed" from one
+or more columns from the underlying mapped table.  The usual example is that
+of a geometric "point", and is introduced in :ref:`mapper_composite`.
+
+As is the case with :class:`.Mutable`, the user-defined composite class
+subclasses :class:`.MutableComposite` as a mixin, and detects and delivers
+change events to its parents via the :meth:`.MutableComposite.changed` method.
+In the case of a composite class, the detection is usually via the usage of
+Python descriptors (i.e. ``@property``), or alternatively via the special
+Python method ``__setattr__()``.  Below we expand upon the ``Point`` class
+introduced in :ref:`mapper_composite` to subclass :class:`.MutableComposite`
+and to also route attribute set events via ``__setattr__`` to the
+:meth:`.MutableComposite.changed` method::
+
+    from sqlalchemy.ext.mutable import MutableComposite
+
+    class Point(MutableComposite):
+        def __init__(self, x, y):
+            self.x = x
+            self.y = y
+
+        def __setattr__(self, key, value):
+            "Intercept set events"
+
+            # set the attribute
+            object.__setattr__(self, key, value)
+
+            # alert all parents to the change
+            self.changed()
+
+        def __composite_values__(self):
+            return self.x, self.y
+
+        def __eq__(self, other):
+            return isinstance(other, Point) and \
+                other.x == self.x and \
+                other.y == self.y
+
+        def __ne__(self, other):
+            return not self.__eq__(other)
+
+The :class:`.MutableComposite` class uses a Python metaclass to automatically
+establish listeners for any usage of :func:`_orm.composite` that specifies our
+``Point`` type.  Below, when ``Point`` is mapped to the ``Vertex`` class,
+listeners are established which will route change events from ``Point``
+objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes::
+
+    from sqlalchemy.orm import composite, mapper
+    from sqlalchemy import Table, Column
+
+    vertices = Table('vertices', metadata,
+        Column('id', Integer, primary_key=True),
+        Column('x1', Integer),
+        Column('y1', Integer),
+        Column('x2', Integer),
+        Column('y2', Integer),
+        )
+
+    class Vertex(object):
+        pass
+
+    mapper(Vertex, vertices, properties={
+        'start': composite(Point, vertices.c.x1, vertices.c.y1),
+        'end': composite(Point, vertices.c.x2, vertices.c.y2)
+    })
+
+Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members
+will flag the attribute as "dirty" on the parent object::
+
+    >>> from sqlalchemy.orm import Session
+
+    >>> sess = Session()
+    >>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15))
+    >>> sess.add(v1)
+    >>> sess.commit()
+
+    >>> v1.end.x = 8
+    >>> v1 in sess.dirty
+    True
+
+Coercing Mutable Composites
+---------------------------
+
+The :meth:`.MutableBase.coerce` method is also supported on composite types.
+In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce` +method is only called for attribute set operations, not load operations. +Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent +to using a :func:`.validates` validation routine for all attributes which +make use of the custom composite type:: + + class Point(MutableComposite): + # other Point methods + # ... + + def coerce(cls, key, value): + if isinstance(value, tuple): + value = Point(*value) + elif not isinstance(value, Point): + raise ValueError("tuple or Point expected") + return value + +Supporting Pickling +-------------------- + +As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper +class uses a ``weakref.WeakKeyDictionary`` available via the +:meth:`MutableBase._parents` attribute which isn't picklable. If we need to +pickle instances of ``Point`` or its owning class ``Vertex``, we at least need +to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary. +Below we define both a ``__getstate__`` and a ``__setstate__`` that package up +the minimal form of our ``Point`` class:: + + class Point(MutableComposite): + # ... + + def __getstate__(self): + return self.x, self.y + + def __setstate__(self, state): + self.x, self.y = state + +As with :class:`.Mutable`, the :class:`.MutableComposite` augments the +pickling process of the parent's object-relational state so that the +:meth:`MutableBase._parents` collection is restored to all ``Point`` objects. + +""" +from collections import defaultdict +import weakref + +from .. import event +from .. import inspect +from .. import types +from ..orm import Mapper +from ..orm import mapper +from ..orm.attributes import flag_modified +from ..sql.base import SchemaEventTarget +from ..util import memoized_property + + +class MutableBase(object): + """Common base class to :class:`.Mutable` + and :class:`.MutableComposite`. + + """ + + @memoized_property + def _parents(self): + """Dictionary of parent object's :class:`.InstanceState`->attribute + name on the parent. + + This attribute is a so-called "memoized" property. It initializes + itself with a new ``weakref.WeakKeyDictionary`` the first time + it is accessed, returning the same object upon subsequent access. + + .. versionchanged:: 1.4 the :class:`.InstanceState` is now used + as the key in the weak dictionary rather than the instance + itself. + + """ + + return weakref.WeakKeyDictionary() + + @classmethod + def coerce(cls, key, value): + """Given a value, coerce it into the target type. + + Can be overridden by custom subclasses to coerce incoming + data into a particular type. + + By default, raises ``ValueError``. + + This method is called in different scenarios depending on if + the parent class is of type :class:`.Mutable` or of type + :class:`.MutableComposite`. In the case of the former, it is called + for both attribute-set operations as well as during ORM loading + operations. For the latter, it is only called during attribute-set + operations; the mechanics of the :func:`.composite` construct + handle coercion during load operations. + + + :param key: string name of the ORM-mapped attribute being set. + :param value: the incoming value. + :return: the method should return the coerced value, or raise + ``ValueError`` if the coercion cannot be completed. 
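+
+        As an illustrative sketch only, a typical override coerces an
+        incoming plain value into the mutable type, deferring to the
+        base implementation to raise ``ValueError`` for unsupported
+        values; the :meth:`.MutableDict.coerce` method below implements
+        exactly this pattern for dictionaries::
+
+            @classmethod
+            def coerce(cls, key, value):
+                if not isinstance(value, cls):
+                    if isinstance(value, dict):
+                        return cls(value)
+                    # base implementation raises ValueError
+                    return Mutable.coerce(key, value)
+                return value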
+ + """ + if value is None: + return None + msg = "Attribute '%s' does not accept objects of type %s" + raise ValueError(msg % (key, type(value))) + + @classmethod + def _get_listen_keys(cls, attribute): + """Given a descriptor attribute, return a ``set()`` of the attribute + keys which indicate a change in the state of this attribute. + + This is normally just ``set([attribute.key])``, but can be overridden + to provide for additional keys. E.g. a :class:`.MutableComposite` + augments this set with the attribute keys associated with the columns + that comprise the composite value. + + This collection is consulted in the case of intercepting the + :meth:`.InstanceEvents.refresh` and + :meth:`.InstanceEvents.refresh_flush` events, which pass along a list + of attribute names that have been refreshed; the list is compared + against this set to determine if action needs to be taken. + + .. versionadded:: 1.0.5 + + """ + return {attribute.key} + + @classmethod + def _listen_on_attribute(cls, attribute, coerce, parent_cls): + """Establish this type as a mutation listener for the given + mapped descriptor. + + """ + key = attribute.key + if parent_cls is not attribute.class_: + return + + # rely on "propagate" here + parent_cls = attribute.class_ + + listen_keys = cls._get_listen_keys(attribute) + + def load(state, *args): + """Listen for objects loaded or refreshed. + + Wrap the target data member's value with + ``Mutable``. + + """ + val = state.dict.get(key, None) + if val is not None: + if coerce: + val = cls.coerce(key, val) + state.dict[key] = val + val._parents[state] = key + + def load_attrs(state, ctx, attrs): + if not attrs or listen_keys.intersection(attrs): + load(state) + + def set_(target, value, oldvalue, initiator): + """Listen for set/replace events on the target + data member. + + Establish a weak reference to the parent object + on the incoming value, remove it for the one + outgoing. + + """ + if value is oldvalue: + return value + + if not isinstance(value, cls): + value = cls.coerce(key, value) + if value is not None: + value._parents[target] = key + if isinstance(oldvalue, cls): + oldvalue._parents.pop(inspect(target), None) + return value + + def pickle(state, state_dict): + val = state.dict.get(key, None) + if val is not None: + if "ext.mutable.values" not in state_dict: + state_dict["ext.mutable.values"] = defaultdict(list) + state_dict["ext.mutable.values"][key].append(val) + + def unpickle(state, state_dict): + if "ext.mutable.values" in state_dict: + collection = state_dict["ext.mutable.values"] + if isinstance(collection, list): + # legacy format + for val in collection: + val._parents[state] = key + else: + for val in state_dict["ext.mutable.values"][key]: + val._parents[state] = key + + event.listen(parent_cls, "load", load, raw=True, propagate=True) + event.listen( + parent_cls, "refresh", load_attrs, raw=True, propagate=True + ) + event.listen( + parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True + ) + event.listen( + attribute, "set", set_, raw=True, retval=True, propagate=True + ) + event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True) + event.listen( + parent_cls, "unpickle", unpickle, raw=True, propagate=True + ) + + +class Mutable(MutableBase): + """Mixin that defines transparent propagation of change + events to a parent object. + + See the example in :ref:`mutable_scalars` for usage information. 
+ + """ + + def changed(self): + """Subclasses should call this method whenever change events occur.""" + + for parent, key in self._parents.items(): + flag_modified(parent.obj(), key) + + @classmethod + def associate_with_attribute(cls, attribute): + """Establish this type as a mutation listener for the given + mapped descriptor. + + """ + cls._listen_on_attribute(attribute, True, attribute.class_) + + @classmethod + def associate_with(cls, sqltype): + """Associate this wrapper with all future mapped columns + of the given type. + + This is a convenience method that calls + ``associate_with_attribute`` automatically. + + .. warning:: + + The listeners established by this method are *global* + to all mappers, and are *not* garbage collected. Only use + :meth:`.associate_with` for types that are permanent to an + application, not with ad-hoc types else this will cause unbounded + growth in memory usage. + + """ + + def listen_for_type(mapper, class_): + if mapper.non_primary: + return + for prop in mapper.column_attrs: + if isinstance(prop.columns[0].type, sqltype): + cls.associate_with_attribute(getattr(class_, prop.key)) + + event.listen(mapper, "mapper_configured", listen_for_type) + + @classmethod + def as_mutable(cls, sqltype): + """Associate a SQL type with this mutable Python type. + + This establishes listeners that will detect ORM mappings against + the given type, adding mutation event trackers to those mappings. + + The type is returned, unconditionally as an instance, so that + :meth:`.as_mutable` can be used inline:: + + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('data', MyMutableType.as_mutable(PickleType)) + ) + + Note that the returned type is always an instance, even if a class + is given, and that only columns which are declared specifically with + that type instance receive additional instrumentation. + + To associate a particular mutable type with all occurrences of a + particular type, use the :meth:`.Mutable.associate_with` classmethod + of the particular :class:`.Mutable` subclass to establish a global + association. + + .. warning:: + + The listeners established by this method are *global* + to all mappers, and are *not* garbage collected. Only use + :meth:`.as_mutable` for types that are permanent to an application, + not with ad-hoc types else this will cause unbounded growth + in memory usage. + + """ + sqltype = types.to_instance(sqltype) + + # a SchemaType will be copied when the Column is copied, + # and we'll lose our ability to link that type back to the original. + # so track our original type w/ columns + if isinstance(sqltype, SchemaEventTarget): + + @event.listens_for(sqltype, "before_parent_attach") + def _add_column_memo(sqltyp, parent): + parent.info["_ext_mutable_orig_type"] = sqltyp + + schema_event_check = True + else: + schema_event_check = False + + def listen_for_type(mapper, class_): + if mapper.non_primary: + return + for prop in mapper.column_attrs: + if ( + schema_event_check + and hasattr(prop.expression, "info") + and prop.expression.info.get("_ext_mutable_orig_type") + is sqltype + ) or (prop.columns[0].type is sqltype): + cls.associate_with_attribute(getattr(class_, prop.key)) + + event.listen(mapper, "mapper_configured", listen_for_type) + + return sqltype + + +class MutableComposite(MutableBase): + """Mixin that defines transparent propagation of change + events on a SQLAlchemy "composite" object to its + owning parent or parents. + + See the example in :ref:`mutable_composites` for usage information. 
+ + """ + + @classmethod + def _get_listen_keys(cls, attribute): + return {attribute.key}.union(attribute.property._attribute_keys) + + def changed(self): + """Subclasses should call this method whenever change events occur.""" + + for parent, key in self._parents.items(): + + prop = parent.mapper.get_property(key) + for value, attr_name in zip( + self.__composite_values__(), prop._attribute_keys + ): + setattr(parent.obj(), attr_name, value) + + +def _setup_composite_listener(): + def _listen_for_type(mapper, class_): + for prop in mapper.iterate_properties: + if ( + hasattr(prop, "composite_class") + and isinstance(prop.composite_class, type) + and issubclass(prop.composite_class, MutableComposite) + ): + prop.composite_class._listen_on_attribute( + getattr(class_, prop.key), False, class_ + ) + + if not event.contains(Mapper, "mapper_configured", _listen_for_type): + event.listen(Mapper, "mapper_configured", _listen_for_type) + + +_setup_composite_listener() + + +class MutableDict(Mutable, dict): + """A dictionary type that implements :class:`.Mutable`. + + The :class:`.MutableDict` object implements a dictionary that will + emit change events to the underlying mapping when the contents of + the dictionary are altered, including when values are added or removed. + + Note that :class:`.MutableDict` does **not** apply mutable tracking to the + *values themselves* inside the dictionary. Therefore it is not a sufficient + solution for the use case of tracking deep changes to a *recursive* + dictionary structure, such as a JSON structure. To support this use case, + build a subclass of :class:`.MutableDict` that provides appropriate + coercion to the values placed in the dictionary so that they too are + "mutable", and emit events up to their parent structure. + + .. seealso:: + + :class:`.MutableList` + + :class:`.MutableSet` + + """ + + def __setitem__(self, key, value): + """Detect dictionary set events and emit change events.""" + dict.__setitem__(self, key, value) + self.changed() + + def setdefault(self, key, value): + result = dict.setdefault(self, key, value) + self.changed() + return result + + def __delitem__(self, key): + """Detect dictionary del events and emit change events.""" + dict.__delitem__(self, key) + self.changed() + + def update(self, *a, **kw): + dict.update(self, *a, **kw) + self.changed() + + def pop(self, *arg): + result = dict.pop(self, *arg) + self.changed() + return result + + def popitem(self): + result = dict.popitem(self) + self.changed() + return result + + def clear(self): + dict.clear(self) + self.changed() + + @classmethod + def coerce(cls, key, value): + """Convert plain dictionary to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, dict): + return cls(value) + return Mutable.coerce(key, value) + else: + return value + + def __getstate__(self): + return dict(self) + + def __setstate__(self, state): + self.update(state) + + +class MutableList(Mutable, list): + """A list type that implements :class:`.Mutable`. + + The :class:`.MutableList` object implements a list that will + emit change events to the underlying mapping when the contents of + the list are altered, including when values are added or removed. + + Note that :class:`.MutableList` does **not** apply mutable tracking to the + *values themselves* inside the list. Therefore it is not a sufficient + solution for the use case of tracking deep changes to a *recursive* + mutable structure, such as a JSON structure. 
To support this use case,
+    build a subclass of :class:`.MutableList` that provides appropriate
+    coercion to the values placed in the list so that they too are
+    "mutable", and emit events up to their parent structure.
+
+    .. versionadded:: 1.1
+
+    .. seealso::
+
+        :class:`.MutableDict`
+
+        :class:`.MutableSet`
+
+    """
+
+    def __reduce_ex__(self, proto):
+        return (self.__class__, (list(self),))
+
+    # needed for backwards compatibility with
+    # older pickles
+    def __setstate__(self, state):
+        self[:] = state
+
+    def __setitem__(self, index, value):
+        """Detect list set events and emit change events."""
+        list.__setitem__(self, index, value)
+        self.changed()
+
+    def __setslice__(self, start, end, value):
+        """Detect list set events and emit change events."""
+        list.__setslice__(self, start, end, value)
+        self.changed()
+
+    def __delitem__(self, index):
+        """Detect list del events and emit change events."""
+        list.__delitem__(self, index)
+        self.changed()
+
+    def __delslice__(self, start, end):
+        """Detect list del events and emit change events."""
+        list.__delslice__(self, start, end)
+        self.changed()
+
+    def pop(self, *arg):
+        result = list.pop(self, *arg)
+        self.changed()
+        return result
+
+    def append(self, x):
+        list.append(self, x)
+        self.changed()
+
+    def extend(self, x):
+        list.extend(self, x)
+        self.changed()
+
+    def __iadd__(self, x):
+        self.extend(x)
+        return self
+
+    def insert(self, i, x):
+        list.insert(self, i, x)
+        self.changed()
+
+    def remove(self, i):
+        list.remove(self, i)
+        self.changed()
+
+    def clear(self):
+        list.clear(self)
+        self.changed()
+
+    def sort(self, **kw):
+        list.sort(self, **kw)
+        self.changed()
+
+    def reverse(self):
+        list.reverse(self)
+        self.changed()
+
+    @classmethod
+    def coerce(cls, index, value):
+        """Convert plain list to instance of this class."""
+        if not isinstance(value, cls):
+            if isinstance(value, list):
+                return cls(value)
+            return Mutable.coerce(index, value)
+        else:
+            return value
+
+
+class MutableSet(Mutable, set):
+    """A set type that implements :class:`.Mutable`.
+
+    The :class:`.MutableSet` object implements a set that will
+    emit change events to the underlying mapping when the contents of
+    the set are altered, including when values are added or removed.
+
+    Note that :class:`.MutableSet` does **not** apply mutable tracking to the
+    *values themselves* inside the set.  Therefore it is not a sufficient
+    solution for the use case of tracking deep changes to a *recursive*
+    mutable structure.  To support this use case,
+    build a subclass of :class:`.MutableSet` that provides appropriate
+    coercion to the values placed in the set so that they too are
+    "mutable", and emit events up to their parent structure.
+
+    .. versionadded:: 1.1
+
+    ..
seealso:: + + :class:`.MutableDict` + + :class:`.MutableList` + + + """ + + def update(self, *arg): + set.update(self, *arg) + self.changed() + + def intersection_update(self, *arg): + set.intersection_update(self, *arg) + self.changed() + + def difference_update(self, *arg): + set.difference_update(self, *arg) + self.changed() + + def symmetric_difference_update(self, *arg): + set.symmetric_difference_update(self, *arg) + self.changed() + + def __ior__(self, other): + self.update(other) + return self + + def __iand__(self, other): + self.intersection_update(other) + return self + + def __ixor__(self, other): + self.symmetric_difference_update(other) + return self + + def __isub__(self, other): + self.difference_update(other) + return self + + def add(self, elem): + set.add(self, elem) + self.changed() + + def remove(self, elem): + set.remove(self, elem) + self.changed() + + def discard(self, elem): + set.discard(self, elem) + self.changed() + + def pop(self, *arg): + result = set.pop(self, *arg) + self.changed() + return result + + def clear(self): + set.clear(self) + self.changed() + + @classmethod + def coerce(cls, index, value): + """Convert plain set to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, set): + return cls(value) + return Mutable.coerce(index, value) + else: + return value + + def __getstate__(self): + return set(self) + + def __setstate__(self, state): + self.update(state) + + def __reduce_ex__(self, proto): + return (self.__class__, (list(self),)) diff --git a/lib/sqlalchemy/ext/mypy/__init__.py b/lib/sqlalchemy/ext/mypy/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/sqlalchemy/ext/mypy/__init__.py diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py new file mode 100644 index 0000000..99be194 --- /dev/null +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -0,0 +1,299 @@ +# ext/mypy/apply.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from typing import List +from typing import Optional +from typing import Union + +from mypy.nodes import ARG_NAMED_OPT +from mypy.nodes import Argument +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import MDEF +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import SymbolTableNode +from mypy.nodes import TempNode +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.plugins.common import add_method_to_class +from mypy.types import AnyType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneTyp +from mypy.types import ProperType +from mypy.types import TypeOfAny +from mypy.types import UnboundType +from mypy.types import UnionType + +from . import infer +from . 
import util +from .names import NAMED_TYPE_SQLA_MAPPED + + +def apply_mypy_mapped_attr( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + item: Union[NameExpr, StrExpr], + attributes: List[util.SQLAlchemyAttribute], +) -> None: + if isinstance(item, NameExpr): + name = item.name + elif isinstance(item, StrExpr): + name = item.value + else: + return None + + for stmt in cls.defs.body: + if ( + isinstance(stmt, AssignmentStmt) + and isinstance(stmt.lvalues[0], NameExpr) + and stmt.lvalues[0].name == name + ): + break + else: + util.fail(api, "Can't find mapped attribute {}".format(name), cls) + return None + + if stmt.type is None: + util.fail( + api, + "Statement linked from _mypy_mapped_attrs has no " + "typing information", + stmt, + ) + return None + + left_hand_explicit_type = get_proper_type(stmt.type) + assert isinstance( + left_hand_explicit_type, (Instance, UnionType, UnboundType) + ) + + attributes.append( + util.SQLAlchemyAttribute( + name=name, + line=item.line, + column=item.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + + apply_type_to_mapped_statement( + api, stmt, stmt.lvalues[0], left_hand_explicit_type, None + ) + + +def re_apply_declarative_assignments( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """For multiple class passes, re-apply our left-hand side types as mypy + seems to reset them in place. + + """ + mapped_attr_lookup = {attr.name: attr for attr in attributes} + update_cls_metadata = False + + for stmt in cls.defs.body: + # for a re-apply, all of our statements are AssignmentStmt; + # @declared_attr calls will have been converted and this + # currently seems to be preserved by mypy (but who knows if this + # will change). + if ( + isinstance(stmt, AssignmentStmt) + and isinstance(stmt.lvalues[0], NameExpr) + and stmt.lvalues[0].name in mapped_attr_lookup + and isinstance(stmt.lvalues[0].node, Var) + ): + + left_node = stmt.lvalues[0].node + python_type_for_type = mapped_attr_lookup[ + stmt.lvalues[0].name + ].type + + left_node_proper_type = get_proper_type(left_node.type) + + # if we have scanned an UnboundType and now there's a more + # specific type than UnboundType, call the re-scan so we + # can get that set up correctly + if ( + isinstance(python_type_for_type, UnboundType) + and not isinstance(left_node_proper_type, UnboundType) + and ( + isinstance(stmt.rvalue, CallExpr) + and isinstance(stmt.rvalue.callee, MemberExpr) + and isinstance(stmt.rvalue.callee.expr, NameExpr) + and stmt.rvalue.callee.expr.node is not None + and stmt.rvalue.callee.expr.node.fullname + == NAMED_TYPE_SQLA_MAPPED + and stmt.rvalue.callee.name == "_empty_constructor" + and isinstance(stmt.rvalue.args[0], CallExpr) + and isinstance(stmt.rvalue.args[0].callee, RefExpr) + ) + ): + + python_type_for_type = ( + infer.infer_type_from_right_hand_nameexpr( + api, + stmt, + left_node, + left_node_proper_type, + stmt.rvalue.args[0].callee, + ) + ) + + if python_type_for_type is None or isinstance( + python_type_for_type, UnboundType + ): + continue + + # update the SQLAlchemyAttribute with the better information + mapped_attr_lookup[ + stmt.lvalues[0].name + ].type = python_type_for_type + + update_cls_metadata = True + + if python_type_for_type is not None: + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, [python_type_for_type] + ) + + if update_cls_metadata: + util.set_mapped_attributes(cls.info, attributes) + + +def apply_type_to_mapped_statement( + api: 
SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + lvalue: NameExpr, + left_hand_explicit_type: Optional[ProperType], + python_type_for_type: Optional[ProperType], +) -> None: + """Apply the Mapped[<type>] annotation and right hand object to a + declarative assignment statement. + + This converts a Python declarative class statement such as:: + + class User(Base): + # ... + + attrname = Column(Integer) + + To one that describes the final Python behavior to Mypy:: + + class User(Base): + # ... + + attrname : Mapped[Optional[int]] = <meaningless temp node> + + """ + left_node = lvalue.node + assert isinstance(left_node, Var) + + if left_hand_explicit_type is not None: + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type] + ) + else: + lvalue.is_inferred_def = False + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, + [] if python_type_for_type is None else [python_type_for_type], + ) + + # so to have it skip the right side totally, we can do this: + # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form)) + + # however, if we instead manufacture a new node that uses the old + # one, then we can still get type checking for the call itself, + # e.g. the Column, relationship() call, etc. + + # rewrite the node as: + # <attr> : Mapped[<typ>] = + # _sa_Mapped._empty_constructor(<original CallExpr from rvalue>) + # the original right-hand side is maintained so it gets type checked + # internally + stmt.rvalue = util.expr_to_mapped_constructor(stmt.rvalue) + + +def add_additional_orm_attributes( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Apply __init__, __table__ and other attributes to the mapped class.""" + + info = util.info_for_cls(cls, api) + + if info is None: + return + + is_base = util.get_is_base(info) + + if "__init__" not in info.names and not is_base: + mapped_attr_names = {attr.name: attr.type for attr in attributes} + + for base in info.mro[1:-1]: + if "sqlalchemy" not in info.metadata: + continue + + base_cls_attributes = util.get_mapped_attributes(base, api) + if base_cls_attributes is None: + continue + + for attr in base_cls_attributes: + mapped_attr_names.setdefault(attr.name, attr.type) + + arguments = [] + for name, typ in mapped_attr_names.items(): + if typ is None: + typ = AnyType(TypeOfAny.special_form) + arguments.append( + Argument( + variable=Var(name, typ), + type_annotation=typ, + initializer=TempNode(typ), + kind=ARG_NAMED_OPT, + ) + ) + + add_method_to_class(api, cls, "__init__", arguments, NoneTyp()) + + if "__table__" not in info.names and util.get_has_table(info): + _apply_placeholder_attr_to_class( + api, cls, "sqlalchemy.sql.schema.Table", "__table__" + ) + if not is_base: + _apply_placeholder_attr_to_class( + api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__" + ) + + +def _apply_placeholder_attr_to_class( + api: SemanticAnalyzerPluginInterface, + cls: ClassDef, + qualified_name: str, + attrname: str, +) -> None: + sym = api.lookup_fully_qualified_or_none(qualified_name) + if sym: + assert isinstance(sym.node, TypeInfo) + type_: ProperType = Instance(sym.node, []) + else: + type_ = AnyType(TypeOfAny.special_form) + var = Var(attrname) + var._fullname = cls.fullname + "." 
+ attrname + var.info = cls.info + var.type = type_ + cls.info.names[attrname] = SymbolTableNode(MDEF, var) diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py new file mode 100644 index 0000000..c33c30e --- /dev/null +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -0,0 +1,516 @@ +# ext/mypy/decl_class.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from typing import List +from typing import Optional +from typing import Union + +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import Decorator +from mypy.nodes import LambdaExpr +from mypy.nodes import ListExpr +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import PlaceholderNode +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import SymbolNode +from mypy.nodes import SymbolTableNode +from mypy.nodes import TempNode +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import AnyType +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import Type +from mypy.types import TypeOfAny +from mypy.types import UnboundType +from mypy.types import UnionType + +from . import apply +from . import infer +from . import names +from . import util + + +def scan_declarative_assignments_and_apply_types( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + is_mixin_scan: bool = False, +) -> Optional[List[util.SQLAlchemyAttribute]]: + + info = util.info_for_cls(cls, api) + + if info is None: + # this can occur during cached passes + return None + elif cls.fullname.startswith("builtins"): + return None + + mapped_attributes: Optional[ + List[util.SQLAlchemyAttribute] + ] = util.get_mapped_attributes(info, api) + + # used by assign.add_additional_orm_attributes among others + util.establish_as_sqlalchemy(info) + + if mapped_attributes is not None: + # ensure that a class that's mapped is always picked up by + # its mapped() decorator or declarative metaclass before + # it would be detected as an unmapped mixin class + + if not is_mixin_scan: + # mypy can call us more than once. it then *may* have reset the + # left hand side of everything, but not the right that we removed, + # removing our ability to re-scan. but we have the types + # here, so lets re-apply them, or if we have an UnboundType, + # we can re-scan + + apply.re_apply_declarative_assignments(cls, api, mapped_attributes) + + return mapped_attributes + + mapped_attributes = [] + + if not cls.defs.body: + # when we get a mixin class from another file, the body is + # empty (!) but the names are in the symbol table. so use that. 
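+        # each SymbolTableNode in info.names still carries the
+        # semantically-analyzed type of the attribute, so the mapping
+        # information can be recovered from it just as it would be from
+        # an AssignmentStmt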
+ + for sym_name, sym in info.names.items(): + _scan_symbol_table_entry( + cls, api, sym_name, sym, mapped_attributes + ) + else: + for stmt in util.flatten_typechecking(cls.defs.body): + if isinstance(stmt, AssignmentStmt): + _scan_declarative_assignment_stmt( + cls, api, stmt, mapped_attributes + ) + elif isinstance(stmt, Decorator): + _scan_declarative_decorator_stmt( + cls, api, stmt, mapped_attributes + ) + _scan_for_mapped_bases(cls, api) + + if not is_mixin_scan: + apply.add_additional_orm_attributes(cls, api, mapped_attributes) + + util.set_mapped_attributes(info, mapped_attributes) + + return mapped_attributes + + +def _scan_symbol_table_entry( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + name: str, + value: SymbolTableNode, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from a SymbolTableNode that's in the + type.names dictionary. + + """ + value_type = get_proper_type(value.type) + if not isinstance(value_type, Instance): + return + + left_hand_explicit_type = None + type_id = names.type_id_for_named_node(value_type.type) + # type_id = names._type_id_for_unbound_type(value.type.type, cls, api) + + err = False + + # TODO: this is nearly the same logic as that of + # _scan_declarative_decorator_stmt, likely can be merged + if type_id in { + names.MAPPED, + names.RELATIONSHIP, + names.COMPOSITE_PROPERTY, + names.MAPPER_PROPERTY, + names.SYNONYM_PROPERTY, + names.COLUMN_PROPERTY, + }: + if value_type.args: + left_hand_explicit_type = get_proper_type(value_type.args[0]) + else: + err = True + elif type_id is names.COLUMN: + if not value_type.args: + err = True + else: + typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type( + value_type.args[0] + ) + if isinstance(typeengine_arg, Instance): + typeengine_arg = typeengine_arg.type + + if isinstance(typeengine_arg, (UnboundType, TypeInfo)): + sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) + if sym is not None and isinstance(sym.node, TypeInfo): + if names.has_base_type_id(sym.node, names.TYPEENGINE): + + left_hand_explicit_type = UnionType( + [ + infer.extract_python_type_from_typeengine( + api, sym.node, [] + ), + NoneType(), + ] + ) + else: + util.fail( + api, + "Column type should be a TypeEngine " + "subclass not '{}'".format(sym.node.fullname), + value_type, + ) + + if err: + msg = ( + "Can't infer type from attribute {} on class {}. " + "please specify a return type from this function that is " + "one of: Mapped[<python type>], relationship[<target class>], " + "Column[<TypeEngine>], MapperProperty[<python type>]" + ) + util.fail(api, msg.format(name, cls.name), cls) + + left_hand_explicit_type = AnyType(TypeOfAny.special_form) + + if left_hand_explicit_type is not None: + assert value.node is not None + attributes.append( + util.SQLAlchemyAttribute( + name=name, + line=value.node.line, + column=value.node.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + + +def _scan_declarative_decorator_stmt( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + stmt: Decorator, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from a @declared_attr in a declarative + class. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + @declared_attr + def updated_at(cls) -> Column[DateTime]: + return Column(DateTime) + + Will resolve in mypy as:: + + @reg.mapped + class MyClass: + # ... 
+ + updated_at: Mapped[Optional[datetime.datetime]] + + """ + for dec in stmt.decorators: + if ( + isinstance(dec, (NameExpr, MemberExpr, SymbolNode)) + and names.type_id_for_named_node(dec) is names.DECLARED_ATTR + ): + break + else: + return + + dec_index = cls.defs.body.index(stmt) + + left_hand_explicit_type: Optional[ProperType] = None + + if util.name_is_dunder(stmt.name): + # for dunder names like __table_args__, __tablename__, + # __mapper_args__ etc., rewrite these as simple assignment + # statements; otherwise mypy doesn't like if the decorated + # function has an annotation like ``cls: Type[Foo]`` because + # it isn't @classmethod + any_ = AnyType(TypeOfAny.special_form) + left_node = NameExpr(stmt.var.name) + left_node.node = stmt.var + new_stmt = AssignmentStmt([left_node], TempNode(any_)) + new_stmt.type = left_node.node.type + cls.defs.body[dec_index] = new_stmt + return + elif isinstance(stmt.func.type, CallableType): + func_type = stmt.func.type.ret_type + if isinstance(func_type, UnboundType): + type_id = names.type_id_for_unbound_type(func_type, cls, api) + else: + # this does not seem to occur unless the type argument is + # incorrect + return + + if ( + type_id + in { + names.MAPPED, + names.RELATIONSHIP, + names.COMPOSITE_PROPERTY, + names.MAPPER_PROPERTY, + names.SYNONYM_PROPERTY, + names.COLUMN_PROPERTY, + } + and func_type.args + ): + left_hand_explicit_type = get_proper_type(func_type.args[0]) + elif type_id is names.COLUMN and func_type.args: + typeengine_arg = func_type.args[0] + if isinstance(typeengine_arg, UnboundType): + sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) + if sym is not None and isinstance(sym.node, TypeInfo): + if names.has_base_type_id(sym.node, names.TYPEENGINE): + left_hand_explicit_type = UnionType( + [ + infer.extract_python_type_from_typeengine( + api, sym.node, [] + ), + NoneType(), + ] + ) + else: + util.fail( + api, + "Column type should be a TypeEngine " + "subclass not '{}'".format(sym.node.fullname), + func_type, + ) + + if left_hand_explicit_type is None: + # no type on the decorated function. our option here is to + # dig into the function body and get the return type, but they + # should just have an annotation. + msg = ( + "Can't infer type from @declared_attr on function '{}'; " + "please specify a return type from this function that is " + "one of: Mapped[<python type>], relationship[<target class>], " + "Column[<TypeEngine>], MapperProperty[<python type>]" + ) + util.fail(api, msg.format(stmt.var.name), stmt) + + left_hand_explicit_type = AnyType(TypeOfAny.special_form) + + left_node = NameExpr(stmt.var.name) + left_node.node = stmt.var + + # totally feeling around in the dark here as I don't totally understand + # the significance of UnboundType. It seems to be something that is + # not going to do what's expected when it is applied as the type of + # an AssignmentStatement. So do a feeling-around-in-the-dark version + # of converting it to the regular Instance/TypeInfo/UnionType structures + # we see everywhere else. 
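+    # normalize the UnboundType into the Instance / UnionType structures
+    # used elsewhere before it is applied as the Mapped[] type argument
+    # below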
+ if isinstance(left_hand_explicit_type, UnboundType): + left_hand_explicit_type = get_proper_type( + util.unbound_to_instance(api, left_hand_explicit_type) + ) + + left_node.node.type = api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type] + ) + + # this will ignore the rvalue entirely + # rvalue = TempNode(AnyType(TypeOfAny.special_form)) + + # rewrite the node as: + # <attr> : Mapped[<typ>] = + # _sa_Mapped._empty_constructor(lambda: <function body>) + # the function body is maintained so it gets type checked internally + rvalue = util.expr_to_mapped_constructor( + LambdaExpr(stmt.func.arguments, stmt.func.body) + ) + + new_stmt = AssignmentStmt([left_node], rvalue) + new_stmt.type = left_node.node.type + + attributes.append( + util.SQLAlchemyAttribute( + name=left_node.name, + line=stmt.line, + column=stmt.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + cls.defs.body[dec_index] = new_stmt + + +def _scan_declarative_assignment_stmt( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from an assignment statement in a + declarative class. + + """ + lvalue = stmt.lvalues[0] + if not isinstance(lvalue, NameExpr): + return + + sym = cls.info.names.get(lvalue.name) + + # this establishes that semantic analysis has taken place, which + # means the nodes are populated and we are called from an appropriate + # hook. + assert sym is not None + node = sym.node + + if isinstance(node, PlaceholderNode): + return + + assert node is lvalue.node + assert isinstance(node, Var) + + if node.name == "__abstract__": + if api.parse_bool(stmt.rvalue) is True: + util.set_is_base(cls.info) + return + elif node.name == "__tablename__": + util.set_has_table(cls.info) + elif node.name.startswith("__"): + return + elif node.name == "_mypy_mapped_attrs": + if not isinstance(stmt.rvalue, ListExpr): + util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt) + else: + for item in stmt.rvalue.items: + if isinstance(item, (NameExpr, StrExpr)): + apply.apply_mypy_mapped_attr(cls, api, item, attributes) + + left_hand_mapped_type: Optional[Type] = None + left_hand_explicit_type: Optional[ProperType] = None + + if node.is_inferred or node.type is None: + if isinstance(stmt.type, UnboundType): + # look for an explicit Mapped[] type annotation on the left + # side with nothing on the right + + # print(stmt.type) + # Mapped?[Optional?[A?]] + + left_hand_explicit_type = stmt.type + + if stmt.type.name == "Mapped": + mapped_sym = api.lookup_qualified("Mapped", cls) + if ( + mapped_sym is not None + and mapped_sym.node is not None + and names.type_id_for_named_node(mapped_sym.node) + is names.MAPPED + ): + left_hand_explicit_type = get_proper_type( + stmt.type.args[0] + ) + left_hand_mapped_type = stmt.type + + # TODO: do we need to convert from unbound for this case? 
+ # left_hand_explicit_type = util._unbound_to_instance( + # api, left_hand_explicit_type + # ) + else: + node_type = get_proper_type(node.type) + if ( + isinstance(node_type, Instance) + and names.type_id_for_named_node(node_type.type) is names.MAPPED + ): + # print(node.type) + # sqlalchemy.orm.attributes.Mapped[<python type>] + left_hand_explicit_type = get_proper_type(node_type.args[0]) + left_hand_mapped_type = node_type + else: + # print(node.type) + # <python type> + left_hand_explicit_type = node_type + left_hand_mapped_type = None + + if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None: + # annotation without assignment and Mapped is present + # as type annotation + # equivalent to using _infer_type_from_left_hand_type_only. + + python_type_for_type = left_hand_explicit_type + elif isinstance(stmt.rvalue, CallExpr) and isinstance( + stmt.rvalue.callee, RefExpr + ): + + python_type_for_type = infer.infer_type_from_right_hand_nameexpr( + api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee + ) + + if python_type_for_type is None: + return + + else: + return + + assert python_type_for_type is not None + + attributes.append( + util.SQLAlchemyAttribute( + name=node.name, + line=stmt.line, + column=stmt.column, + typ=python_type_for_type, + info=cls.info, + ) + ) + + apply.apply_type_to_mapped_statement( + api, + stmt, + lvalue, + left_hand_explicit_type, + python_type_for_type, + ) + + +def _scan_for_mapped_bases( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, +) -> None: + """Given a class, iterate through its superclass hierarchy to find + all other classes that are considered as ORM-significant. + + Locates non-mapped mixins and scans them for mapped attributes to be + applied to subclasses. + + """ + + info = util.info_for_cls(cls, api) + + if info is None: + return + + for base_info in info.mro[1:-1]: + if base_info.fullname.startswith("builtins"): + continue + + # scan each base for mapped attributes. if they are not already + # scanned (but have all their type info), that means they are unmapped + # mixins + scan_declarative_assignments_and_apply_types( + base_info.defn, api, is_mixin_scan=True + ) diff --git a/lib/sqlalchemy/ext/mypy/infer.py b/lib/sqlalchemy/ext/mypy/infer.py new file mode 100644 index 0000000..f88a960 --- /dev/null +++ b/lib/sqlalchemy/ext/mypy/infer.py @@ -0,0 +1,556 @@ +# ext/mypy/infer.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from typing import Optional +from typing import Sequence + +from mypy.maptype import map_instance_to_supertype +from mypy.messages import format_type +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import Expression +from mypy.nodes import FuncDef +from mypy.nodes import LambdaExpr +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.subtypes import is_subtype +from mypy.types import AnyType +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import TypeOfAny +from mypy.types import UnionType + +from . import names +from . 
import util + + +def infer_type_from_right_hand_nameexpr( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + infer_from_right_side: RefExpr, +) -> Optional[ProperType]: + + type_id = names.type_id_for_callee(infer_from_right_side) + + if type_id is None: + return None + elif type_id is names.COLUMN: + python_type_for_type = _infer_type_from_decl_column( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.RELATIONSHIP: + python_type_for_type = _infer_type_from_relationship( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.COLUMN_PROPERTY: + python_type_for_type = _infer_type_from_decl_column_property( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.SYNONYM_PROPERTY: + python_type_for_type = infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif type_id is names.COMPOSITE_PROPERTY: + python_type_for_type = _infer_type_from_decl_composite_property( + api, stmt, node, left_hand_explicit_type + ) + else: + return None + + return python_type_for_type + + +def _infer_type_from_relationship( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a relationship. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + addresses = relationship(Address, uselist=True) + + order: Mapped["Order"] = relationship("Order") + + Will resolve in mypy as:: + + @reg.mapped + class MyClass: + # ... + + addresses: Mapped[List[Address]] + + order: Mapped["Order"] + + """ + + assert isinstance(stmt.rvalue, CallExpr) + target_cls_arg = stmt.rvalue.args[0] + python_type_for_type: Optional[ProperType] = None + + if isinstance(target_cls_arg, NameExpr) and isinstance( + target_cls_arg.node, TypeInfo + ): + # type + related_object_type = target_cls_arg.node + python_type_for_type = Instance(related_object_type, []) + + # other cases not covered - an error message directs the user + # to set an explicit type annotation + # + # node.type == str, it's a string + # if isinstance(target_cls_arg, NameExpr) and isinstance( + # target_cls_arg.node, Var + # ) + # points to a type + # isinstance(target_cls_arg, NameExpr) and isinstance( + # target_cls_arg.node, TypeAlias + # ) + # string expression + # isinstance(target_cls_arg, StrExpr) + + uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist") + collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg( + stmt.rvalue, "collection_class" + ) + type_is_a_collection = False + + # this can be used to determine Optional for a many-to-one + # in the same way nullable=False could be used, if we start supporting + # that. 
+ # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin") + + if ( + uselist_arg is not None + and api.parse_bool(uselist_arg) is True + and collection_cls_arg is None + ): + type_is_a_collection = True + if python_type_for_type is not None: + python_type_for_type = api.named_type( + names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type] + ) + elif ( + uselist_arg is None or api.parse_bool(uselist_arg) is True + ) and collection_cls_arg is not None: + type_is_a_collection = True + if isinstance(collection_cls_arg, CallExpr): + collection_cls_arg = collection_cls_arg.callee + + if isinstance(collection_cls_arg, NameExpr) and isinstance( + collection_cls_arg.node, TypeInfo + ): + if python_type_for_type is not None: + # this can still be overridden by the left hand side + # within _infer_Type_from_left_and_inferred_right + python_type_for_type = Instance( + collection_cls_arg.node, [python_type_for_type] + ) + elif ( + isinstance(collection_cls_arg, NameExpr) + and isinstance(collection_cls_arg.node, FuncDef) + and collection_cls_arg.node.type is not None + ): + if python_type_for_type is not None: + # this can still be overridden by the left hand side + # within _infer_Type_from_left_and_inferred_right + + # TODO: handle mypy.types.Overloaded + if isinstance(collection_cls_arg.node.type, CallableType): + rt = get_proper_type(collection_cls_arg.node.type.ret_type) + + if isinstance(rt, CallableType): + callable_ret_type = get_proper_type(rt.ret_type) + if isinstance(callable_ret_type, Instance): + python_type_for_type = Instance( + callable_ret_type.type, + [python_type_for_type], + ) + else: + util.fail( + api, + "Expected Python collection type for " + "collection_class parameter", + stmt.rvalue, + ) + python_type_for_type = None + elif uselist_arg is not None and api.parse_bool(uselist_arg) is False: + if collection_cls_arg is not None: + util.fail( + api, + "Sending uselist=False and collection_class at the same time " + "does not make sense", + stmt.rvalue, + ) + if python_type_for_type is not None: + python_type_for_type = UnionType( + [python_type_for_type, NoneType()] + ) + + else: + if left_hand_explicit_type is None: + msg = ( + "Can't infer scalar or collection for ORM mapped expression " + "assigned to attribute '{}' if both 'uselist' and " + "'collection_class' arguments are absent from the " + "relationship(); please specify a " + "type annotation on the left hand side." 
+ ) + util.fail(api, msg.format(node.name), node) + + if python_type_for_type is None: + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif left_hand_explicit_type is not None: + if type_is_a_collection: + assert isinstance(left_hand_explicit_type, Instance) + assert isinstance(python_type_for_type, Instance) + return _infer_collection_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return _infer_type_from_left_and_inferred_right( + api, + node, + left_hand_explicit_type, + python_type_for_type, + ) + else: + return python_type_for_type + + +def _infer_type_from_decl_composite_property( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a CompositeProperty.""" + + assert isinstance(stmt.rvalue, CallExpr) + target_cls_arg = stmt.rvalue.args[0] + python_type_for_type = None + + if isinstance(target_cls_arg, NameExpr) and isinstance( + target_cls_arg.node, TypeInfo + ): + related_object_type = target_cls_arg.node + python_type_for_type = Instance(related_object_type, []) + else: + python_type_for_type = None + + if python_type_for_type is None: + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif left_hand_explicit_type is not None: + return _infer_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return python_type_for_type + + +def _infer_type_from_decl_column_property( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a ColumnProperty. + + This includes mappings against ``column_property()`` as well as the + ``deferred()`` function. + + """ + assert isinstance(stmt.rvalue, CallExpr) + + if stmt.rvalue.args: + first_prop_arg = stmt.rvalue.args[0] + + if isinstance(first_prop_arg, CallExpr): + type_id = names.type_id_for_callee(first_prop_arg.callee) + + # look for column_property() / deferred() etc with Column as first + # argument + if type_id is names.COLUMN: + return _infer_type_from_decl_column( + api, + stmt, + node, + left_hand_explicit_type, + right_hand_expression=first_prop_arg, + ) + + if isinstance(stmt.rvalue, CallExpr): + type_id = names.type_id_for_callee(stmt.rvalue.callee) + # this is probably not strictly necessary as we have to use the left + # hand type for query expression in any case. any other no-arg + # column prop objects would go here also + if type_id is names.QUERY_EXPRESSION: + return _infer_type_from_decl_column( + api, + stmt, + node, + left_hand_explicit_type, + ) + + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_decl_column( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + right_hand_expression: Optional[CallExpr] = None, +) -> Optional[ProperType]: + """Infer the type of mapping from a Column. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + a = Column(Integer) + + b = Column("b", String) + + c: Mapped[int] = Column(Integer) + + d: bool = Column(Boolean) + + Will resolve in MyPy as:: + + @reg.mapped + class MyClass: + # ... 
+ + a : Mapped[int] + + b : Mapped[str] + + c: Mapped[int] + + d: Mapped[bool] + + """ + assert isinstance(node, Var) + + callee = None + + if right_hand_expression is None: + if not isinstance(stmt.rvalue, CallExpr): + return None + + right_hand_expression = stmt.rvalue + + for column_arg in right_hand_expression.args[0:2]: + if isinstance(column_arg, CallExpr): + if isinstance(column_arg.callee, RefExpr): + # x = Column(String(50)) + callee = column_arg.callee + type_args: Sequence[Expression] = column_arg.args + break + elif isinstance(column_arg, (NameExpr, MemberExpr)): + if isinstance(column_arg.node, TypeInfo): + # x = Column(String) + callee = column_arg + type_args = () + break + else: + # x = Column(some_name, String), go to next argument + continue + elif isinstance(column_arg, (StrExpr,)): + # x = Column("name", String), go to next argument + continue + elif isinstance(column_arg, (LambdaExpr,)): + # x = Column("name", String, default=lambda: uuid.uuid4()) + # go to next argument + continue + else: + assert False + + if callee is None: + return None + + if isinstance(callee.node, TypeInfo) and names.mro_has_id( + callee.node.mro, names.TYPEENGINE + ): + python_type_for_type = extract_python_type_from_typeengine( + api, callee.node, type_args + ) + + if left_hand_explicit_type is not None: + + return _infer_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + + else: + return UnionType([python_type_for_type, NoneType()]) + else: + # it's not TypeEngine, it's typically implicitly typed + # like ForeignKey. we can't infer from the right side. + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_left_and_inferred_right( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: ProperType, + python_type_for_type: ProperType, + orig_left_hand_type: Optional[ProperType] = None, + orig_python_type_for_type: Optional[ProperType] = None, +) -> Optional[ProperType]: + """Validate type when a left hand annotation is present and we also + could infer the right hand side:: + + attrname: SomeType = Column(SomeDBType) + + """ + + if orig_left_hand_type is None: + orig_left_hand_type = left_hand_explicit_type + if orig_python_type_for_type is None: + orig_python_type_for_type = python_type_for_type + + if not is_subtype(left_hand_explicit_type, python_type_for_type): + effective_type = api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type] + ) + + msg = ( + "Left hand assignment '{}: {}' not compatible " + "with ORM mapped expression of type {}" + ) + util.fail( + api, + msg.format( + node.name, + format_type(orig_left_hand_type), + format_type(effective_type), + ), + node, + ) + + return orig_left_hand_type + + +def _infer_collection_type_from_left_and_inferred_right( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Instance, + python_type_for_type: Instance, +) -> Optional[ProperType]: + orig_left_hand_type = left_hand_explicit_type + orig_python_type_for_type = python_type_for_type + + if left_hand_explicit_type.args: + left_hand_arg = get_proper_type(left_hand_explicit_type.args[0]) + python_type_arg = get_proper_type(python_type_for_type.args[0]) + else: + left_hand_arg = left_hand_explicit_type + python_type_arg = python_type_for_type + + assert isinstance(left_hand_arg, (Instance, UnionType)) + assert isinstance(python_type_arg, (Instance, UnionType)) + + return _infer_type_from_left_and_inferred_right( + api, 
+ node, + left_hand_arg, + python_type_arg, + orig_left_hand_type=orig_left_hand_type, + orig_python_type_for_type=orig_python_type_for_type, + ) + + +def infer_type_from_left_hand_type_only( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Determine the type based on explicit annotation only. + + if no annotation were present, note that we need one there to know + the type. + + """ + if left_hand_explicit_type is None: + msg = ( + "Can't infer type from ORM mapped expression " + "assigned to attribute '{}'; please specify a " + "Python type or " + "Mapped[<python type>] on the left hand side." + ) + util.fail(api, msg.format(node.name), node) + + return api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)] + ) + + else: + # use type from the left hand side + return left_hand_explicit_type + + +def extract_python_type_from_typeengine( + api: SemanticAnalyzerPluginInterface, + node: TypeInfo, + type_args: Sequence[Expression], +) -> ProperType: + if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args: + first_arg = type_args[0] + if isinstance(first_arg, RefExpr) and isinstance( + first_arg.node, TypeInfo + ): + for base_ in first_arg.node.mro: + if base_.fullname == "enum.Enum": + return Instance(first_arg.node, []) + # TODO: support other pep-435 types here + else: + return api.named_type(names.NAMED_TYPE_BUILTINS_STR, []) + + assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), ( + "could not extract Python type from node: %s" % node + ) + + type_engine_sym = api.lookup_fully_qualified_or_none( + "sqlalchemy.sql.type_api.TypeEngine" + ) + + assert type_engine_sym is not None and isinstance( + type_engine_sym.node, TypeInfo + ) + type_engine = map_instance_to_supertype( + Instance(node, []), + type_engine_sym.node, + ) + return get_proper_type(type_engine.args[-1]) diff --git a/lib/sqlalchemy/ext/mypy/names.py b/lib/sqlalchemy/ext/mypy/names.py new file mode 100644 index 0000000..8ec15a6 --- /dev/null +++ b/lib/sqlalchemy/ext/mypy/names.py @@ -0,0 +1,253 @@ +# ext/mypy/names.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union + +from mypy.nodes import ClassDef +from mypy.nodes import Expression +from mypy.nodes import FuncDef +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import SymbolNode +from mypy.nodes import TypeAlias +from mypy.nodes import TypeInfo +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import UnboundType + +from ... 
+from ... import util
+
+COLUMN: int = util.symbol("COLUMN")  # type: ignore
+RELATIONSHIP: int = util.symbol("RELATIONSHIP")  # type: ignore
+REGISTRY: int = util.symbol("REGISTRY")  # type: ignore
+COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY")  # type: ignore
+TYPEENGINE: int = util.symbol("TYPEENGINE")  # type: ignore
+MAPPED: int = util.symbol("MAPPED")  # type: ignore
+DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE")  # type: ignore
+DECLARATIVE_META: int = util.symbol("DECLARATIVE_META")  # type: ignore
+MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR")  # type: ignore
+SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY")  # type: ignore
+COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY")  # type: ignore
+DECLARED_ATTR: int = util.symbol("DECLARED_ATTR")  # type: ignore
+MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY")  # type: ignore
+AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE")  # type: ignore
+AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE")  # type: ignore
+DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN")  # type: ignore
+QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION")  # type: ignore
+
+# names that must succeed with mypy.api.named_type
+NAMED_TYPE_BUILTINS_OBJECT = "builtins.object"
+NAMED_TYPE_BUILTINS_STR = "builtins.str"
+NAMED_TYPE_BUILTINS_LIST = "builtins.list"
+NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.attributes.Mapped"
+
+_lookup: Dict[str, Tuple[int, Set[str]]] = {
+    "Column": (
+        COLUMN,
+        {
+            "sqlalchemy.sql.schema.Column",
+            "sqlalchemy.sql.Column",
+        },
+    ),
+    "RelationshipProperty": (
+        RELATIONSHIP,
+        {
+            "sqlalchemy.orm.relationships.RelationshipProperty",
+            "sqlalchemy.orm.RelationshipProperty",
+        },
+    ),
+    "registry": (
+        REGISTRY,
+        {
+            "sqlalchemy.orm.decl_api.registry",
+            "sqlalchemy.orm.registry",
+        },
+    ),
+    "ColumnProperty": (
+        COLUMN_PROPERTY,
+        {
+            "sqlalchemy.orm.properties.ColumnProperty",
+            "sqlalchemy.orm.ColumnProperty",
+        },
+    ),
+    "SynonymProperty": (
+        SYNONYM_PROPERTY,
+        {
+            "sqlalchemy.orm.descriptor_props.SynonymProperty",
+            "sqlalchemy.orm.SynonymProperty",
+        },
+    ),
+    "CompositeProperty": (
+        COMPOSITE_PROPERTY,
+        {
+            "sqlalchemy.orm.descriptor_props.CompositeProperty",
+            "sqlalchemy.orm.CompositeProperty",
+        },
+    ),
+    "MapperProperty": (
+        MAPPER_PROPERTY,
+        {
+            "sqlalchemy.orm.interfaces.MapperProperty",
+            "sqlalchemy.orm.MapperProperty",
+        },
+    ),
+    "TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}),
+    "Mapped": (MAPPED, {"sqlalchemy.orm.attributes.Mapped"}),
+    "declarative_base": (
+        DECLARATIVE_BASE,
+        {
+            "sqlalchemy.ext.declarative.declarative_base",
+            "sqlalchemy.orm.declarative_base",
+            "sqlalchemy.orm.decl_api.declarative_base",
+        },
+    ),
+    "DeclarativeMeta": (
+        DECLARATIVE_META,
+        {
+            "sqlalchemy.ext.declarative.DeclarativeMeta",
+            "sqlalchemy.orm.DeclarativeMeta",
+            "sqlalchemy.orm.decl_api.DeclarativeMeta",
+        },
+    ),
+    "mapped": (
+        MAPPED_DECORATOR,
+        {
+            "sqlalchemy.orm.decl_api.registry.mapped",
+            "sqlalchemy.orm.registry.mapped",
+        },
+    ),
+    "as_declarative": (
+        AS_DECLARATIVE,
+        {
+            "sqlalchemy.ext.declarative.as_declarative",
+            "sqlalchemy.orm.decl_api.as_declarative",
+            "sqlalchemy.orm.as_declarative",
+        },
+    ),
+    "as_declarative_base": (
+        AS_DECLARATIVE_BASE,
+        {
+            "sqlalchemy.orm.decl_api.registry.as_declarative_base",
+            "sqlalchemy.orm.registry.as_declarative_base",
+        },
+    ),
+    "declared_attr": (
+        DECLARED_ATTR,
+        {
+            "sqlalchemy.orm.decl_api.declared_attr",
"sqlalchemy.orm.declared_attr", + }, + ), + "declarative_mixin": ( + DECLARATIVE_MIXIN, + { + "sqlalchemy.orm.decl_api.declarative_mixin", + "sqlalchemy.orm.declarative_mixin", + }, + ), + "query_expression": ( + QUERY_EXPRESSION, + {"sqlalchemy.orm.query_expression"}, + ), +} + + +def has_base_type_id(info: TypeInfo, type_id: int) -> bool: + for mr in info.mro: + check_type_id, fullnames = _lookup.get(mr.name, (None, None)) + if check_type_id == type_id: + break + else: + return False + + if fullnames is None: + return False + + return mr.fullname in fullnames + + +def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool: + for mr in mro: + check_type_id, fullnames = _lookup.get(mr.name, (None, None)) + if check_type_id == type_id: + break + else: + return False + + if fullnames is None: + return False + + return mr.fullname in fullnames + + +def type_id_for_unbound_type( + type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface +) -> Optional[int]: + sym = api.lookup_qualified(type_.name, type_) + if sym is not None: + if isinstance(sym.node, TypeAlias): + target_type = get_proper_type(sym.node.target) + if isinstance(target_type, Instance): + return type_id_for_named_node(target_type.type) + elif isinstance(sym.node, TypeInfo): + return type_id_for_named_node(sym.node) + + return None + + +def type_id_for_callee(callee: Expression) -> Optional[int]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, FuncDef): + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + return type_id_for_fullname(target_type.type.fullname) + elif isinstance(callee.node, TypeInfo): + return type_id_for_named_node(callee) + return None + + +def type_id_for_named_node( + node: Union[NameExpr, MemberExpr, SymbolNode] +) -> Optional[int]: + type_id, fullnames = _lookup.get(node.name, (None, None)) + + if type_id is None or fullnames is None: + return None + elif node.fullname in fullnames: + return type_id + else: + return None + + +def type_id_for_fullname(fullname: str) -> Optional[int]: + tokens = fullname.split(".") + immediate = tokens[-1] + + type_id, fullnames = _lookup.get(immediate, (None, None)) + + if type_id is None or fullnames is None: + return None + elif fullname in fullnames: + return type_id + else: + return None diff --git a/lib/sqlalchemy/ext/mypy/plugin.py b/lib/sqlalchemy/ext/mypy/plugin.py new file mode 100644 index 0000000..8687012 --- /dev/null +++ b/lib/sqlalchemy/ext/mypy/plugin.py @@ -0,0 +1,284 @@ +# ext/mypy/plugin.py +# Copyright (C) 2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +Mypy plugin for SQLAlchemy ORM. 
+ +""" +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type as TypingType +from typing import Union + +from mypy import nodes +from mypy.mro import calculate_mro +from mypy.mro import MroError +from mypy.nodes import Block +from mypy.nodes import ClassDef +from mypy.nodes import GDEF +from mypy.nodes import MypyFile +from mypy.nodes import NameExpr +from mypy.nodes import SymbolTable +from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeInfo +from mypy.plugin import AttributeContext +from mypy.plugin import ClassDefContext +from mypy.plugin import DynamicClassDefContext +from mypy.plugin import Plugin +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import Type + +from . import decl_class +from . import names +from . import util + + +class SQLAlchemyPlugin(Plugin): + def get_dynamic_class_hook( + self, fullname: str + ) -> Optional[Callable[[DynamicClassDefContext], None]]: + if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE: + return _dynamic_class_hook + return None + + def get_customize_class_mro_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return _fill_in_decorators + + def get_class_decorator_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + + sym = self.lookup_fully_qualified(fullname) + + if sym is not None and sym.node is not None: + type_id = names.type_id_for_named_node(sym.node) + if type_id is names.MAPPED_DECORATOR: + return _cls_decorator_hook + elif type_id in ( + names.AS_DECLARATIVE, + names.AS_DECLARATIVE_BASE, + ): + return _base_cls_decorator_hook + elif type_id is names.DECLARATIVE_MIXIN: + return _declarative_mixin_hook + + return None + + def get_metaclass_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META: + # Set any classes that explicitly have metaclass=DeclarativeMeta + # as declarative so the check in `get_base_class_hook()` works + return _metaclass_cls_hook + + return None + + def get_base_class_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + sym = self.lookup_fully_qualified(fullname) + + if ( + sym + and isinstance(sym.node, TypeInfo) + and util.has_declarative_base(sym.node) + ): + return _base_cls_hook + + return None + + def get_attribute_hook( + self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: + if fullname.startswith( + "sqlalchemy.orm.attributes.QueryableAttribute." 
+        ):
+            return _queryable_getattr_hook
+
+        return None
+
+    def get_additional_deps(
+        self, file: MypyFile
+    ) -> List[Tuple[int, str, int]]:
+        return [
+            (10, "sqlalchemy.orm.attributes", -1),
+            (10, "sqlalchemy.orm.decl_api", -1),
+        ]
+
+
+def plugin(version: str) -> TypingType[SQLAlchemyPlugin]:
+    return SQLAlchemyPlugin
+
+
+def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None:
+    """Generate a declarative Base class when the declarative_base() function
+    is encountered."""
+
+    _add_globals(ctx)
+
+    cls = ClassDef(ctx.name, Block([]))
+    cls.fullname = ctx.api.qualified_name(ctx.name)
+
+    info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id)
+    cls.info = info
+    _set_declarative_metaclass(ctx.api, cls)
+
+    cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,))
+    if cls_arg is not None and isinstance(cls_arg.node, TypeInfo):
+        util.set_is_base(cls_arg.node)
+        decl_class.scan_declarative_assignments_and_apply_types(
+            cls_arg.node.defn, ctx.api, is_mixin_scan=True
+        )
+        info.bases = [Instance(cls_arg.node, [])]
+    else:
+        obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
+
+        info.bases = [obj]
+
+    try:
+        calculate_mro(info)
+    except MroError:
+        util.fail(
+            ctx.api, "Not able to calculate MRO for declarative base", ctx.call
+        )
+        obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT)
+        info.bases = [obj]
+        info.fallback_to_any = True
+
+    ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
+    util.set_is_base(info)
+
+
+def _fill_in_decorators(ctx: ClassDefContext) -> None:
+    for decorator in ctx.cls.decorators:
+        # set the ".fullname" attribute of a class decorator
+        # that is a MemberExpr.  This causes the logic in
+        # semanal.py->apply_class_plugin_hooks to invoke the
+        # get_class_decorator_hook for our "registry.mapped()"
+        # and "registry.as_declarative_base()" methods.
+        # this seems like a bug in mypy that these decorators are otherwise
+        # skipped.
+
+        if (
+            isinstance(decorator, nodes.CallExpr)
+            and isinstance(decorator.callee, nodes.MemberExpr)
+            and decorator.callee.name == "as_declarative_base"
+        ):
+            target = decorator.callee
+        elif (
+            isinstance(decorator, nodes.MemberExpr)
+            and decorator.name == "mapped"
+        ):
+            target = decorator
+        else:
+            continue
+
+        assert isinstance(target.expr, NameExpr)
+        sym = ctx.api.lookup_qualified(
+            target.expr.name, target, suppress_errors=True
+        )
+        if sym and sym.node:
+            sym_type = get_proper_type(sym.type)
+            if isinstance(sym_type, Instance):
+                target.fullname = f"{sym_type.type.fullname}.{target.name}"
+            else:
+                # if the registry is in the same file as where the
+                # decorator is used, it might not have semantic
+                # symbols applied and we can't get a fully qualified
+                # name or an inferred type, so we are actually going to
+                # flag an error in this case that they need to annotate
+                # it.  The "registry" is declared just
+                # once (or a few times), so they have to just not use
+                # type inference for its assignment in this one case.
+                util.fail(
+                    ctx.api,
+                    "Class decorator called %s(), but we can't "
+                    "tell if it's from an ORM registry.  Please "
+                    "annotate the registry assignment, e.g. "
" + "my_registry: registry = registry()" % target.name, + sym.node, + ) + + +def _cls_decorator_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + assert isinstance(ctx.reason, nodes.MemberExpr) + expr = ctx.reason.expr + + assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var) + + node_type = get_proper_type(expr.node.type) + + assert ( + isinstance(node_type, Instance) + and names.type_id_for_named_node(node_type.type) is names.REGISTRY + ) + + decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + + +def _base_cls_decorator_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + + cls = ctx.cls + + _set_declarative_metaclass(ctx.api, cls) + + util.set_is_base(ctx.cls.info) + decl_class.scan_declarative_assignments_and_apply_types( + cls, ctx.api, is_mixin_scan=True + ) + + +def _declarative_mixin_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + util.set_is_base(ctx.cls.info) + decl_class.scan_declarative_assignments_and_apply_types( + ctx.cls, ctx.api, is_mixin_scan=True + ) + + +def _metaclass_cls_hook(ctx: ClassDefContext) -> None: + util.set_is_base(ctx.cls.info) + + +def _base_cls_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + + +def _queryable_getattr_hook(ctx: AttributeContext) -> Type: + # how do I....tell it it has no attribute of a certain name? + # can't find any Type that seems to match that + return ctx.default_attr_type + + +def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None: + """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space + for all class defs + + """ + + util.add_global(ctx, "sqlalchemy.orm.attributes", "Mapped", "__sa_Mapped") + + +def _set_declarative_metaclass( + api: SemanticAnalyzerPluginInterface, target_cls: ClassDef +) -> None: + info = target_cls.info + sym = api.lookup_fully_qualified_or_none( + "sqlalchemy.orm.decl_api.DeclarativeMeta" + ) + assert sym is not None and isinstance(sym.node, TypeInfo) + info.declared_metaclass = info.metaclass_type = Instance(sym.node, []) diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py new file mode 100644 index 0000000..16b365e --- /dev/null +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -0,0 +1,305 @@ +import re +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type as TypingType +from typing import TypeVar +from typing import Union + +from mypy.nodes import ARG_POS +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import CLASSDEF_NO_INFO +from mypy.nodes import Context +from mypy.nodes import Expression +from mypy.nodes import IfStmt +from mypy.nodes import JsonDict +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import Statement +from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeInfo +from mypy.plugin import ClassDefContext +from mypy.plugin import DynamicClassDefContext +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.plugins.common import deserialize_and_fixup_type +from mypy.typeops import map_type_from_supertype +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import Type +from mypy.types import TypeVarType +from mypy.types import UnboundType +from mypy.types import UnionType + + +_TArgType = 
TypeVar("_TArgType", bound=Union[CallExpr, NameExpr]) + + +class SQLAlchemyAttribute: + def __init__( + self, + name: str, + line: int, + column: int, + typ: Optional[Type], + info: TypeInfo, + ) -> None: + self.name = name + self.line = line + self.column = column + self.type = typ + self.info = info + + def serialize(self) -> JsonDict: + assert self.type + return { + "name": self.name, + "line": self.line, + "column": self.column, + "type": self.type.serialize(), + } + + def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: + """Expands type vars in the context of a subtype when an attribute is + inherited from a generic super type. + """ + if not isinstance(self.type, TypeVarType): + return + + self.type = map_type_from_supertype(self.type, sub_type, self.info) + + @classmethod + def deserialize( + cls, + info: TypeInfo, + data: JsonDict, + api: SemanticAnalyzerPluginInterface, + ) -> "SQLAlchemyAttribute": + data = data.copy() + typ = deserialize_and_fixup_type(data.pop("type"), api) + return cls(typ=typ, info=info, **data) + + +def name_is_dunder(name): + return bool(re.match(r"^__.+?__$", name)) + + +def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None: + info.metadata.setdefault("sqlalchemy", {})[key] = data + + +def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]: + return info.metadata.get("sqlalchemy", {}).get(key, None) + + +def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]: + if info.mro: + for base in info.mro: + metadata = _get_info_metadata(base, key) + if metadata is not None: + return metadata + return None + + +def establish_as_sqlalchemy(info: TypeInfo) -> None: + info.metadata.setdefault("sqlalchemy", {}) + + +def set_is_base(info: TypeInfo) -> None: + _set_info_metadata(info, "is_base", True) + + +def get_is_base(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "is_base") + return is_base is True + + +def has_declarative_base(info: TypeInfo) -> bool: + is_base = _get_info_mro_metadata(info, "is_base") + return is_base is True + + +def set_has_table(info: TypeInfo) -> None: + _set_info_metadata(info, "has_table", True) + + +def get_has_table(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "has_table") + return is_base is True + + +def get_mapped_attributes( + info: TypeInfo, api: SemanticAnalyzerPluginInterface +) -> Optional[List[SQLAlchemyAttribute]]: + mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata( + info, "mapped_attributes" + ) + if mapped_attributes is None: + return None + + attributes: List[SQLAlchemyAttribute] = [] + + for data in mapped_attributes: + attr = SQLAlchemyAttribute.deserialize(info, data, api) + attr.expand_typevar_from_subtype(info) + attributes.append(attr) + + return attributes + + +def set_mapped_attributes( + info: TypeInfo, attributes: List[SQLAlchemyAttribute] +) -> None: + _set_info_metadata( + info, + "mapped_attributes", + [attribute.serialize() for attribute in attributes], + ) + + +def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None: + msg = "[SQLAlchemy Mypy plugin] %s" % msg + return api.fail(msg, ctx) + + +def add_global( + ctx: Union[ClassDefContext, DynamicClassDefContext], + module: str, + symbol_name: str, + asname: str, +) -> None: + module_globals = ctx.api.modules[ctx.api.cur_mod_id].names + + if asname not in module_globals: + lookup_sym: SymbolTableNode = ctx.api.modules[module].names[ + symbol_name + ] + + module_globals[asname] = lookup_sym + + +@overload +def get_callexpr_kwarg( + 
+    callexpr: CallExpr, name: str, *, expr_types: None = ...
+) -> Optional[Union[CallExpr, NameExpr]]:
+    ...
+
+
+@overload
+def get_callexpr_kwarg(
+    callexpr: CallExpr,
+    name: str,
+    *,
+    expr_types: Tuple[TypingType[_TArgType], ...]
+) -> Optional[_TArgType]:
+    ...
+
+
+def get_callexpr_kwarg(
+    callexpr: CallExpr,
+    name: str,
+    *,
+    expr_types: Optional[Tuple[TypingType[Any], ...]] = None
+) -> Optional[Any]:
+    try:
+        arg_idx = callexpr.arg_names.index(name)
+    except ValueError:
+        return None
+
+    kwarg = callexpr.args[arg_idx]
+    if isinstance(
+        kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr)
+    ):
+        return kwarg
+
+    return None
+
+
+def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]:
+    for stmt in stmts:
+        if (
+            isinstance(stmt, IfStmt)
+            and isinstance(stmt.expr[0], NameExpr)
+            and stmt.expr[0].fullname == "typing.TYPE_CHECKING"
+        ):
+            for substmt in stmt.body[0].body:
+                yield substmt
+        else:
+            yield stmt
+
+
+def unbound_to_instance(
+    api: SemanticAnalyzerPluginInterface, typ: Type
+) -> Type:
+    """Take the UnboundType that we seem to get as the ret_type from a FuncDef
+    and convert it into an Instance/TypeInfo kind of structure that seems
+    to work as the left-hand type of an AssignmentStatement.
+
+    """
+
+    if not isinstance(typ, UnboundType):
+        return typ
+
+    # TODO: figure out a more robust way to check this.  The node is some
+    # kind of _SpecialForm; there's a typing.Optional that's a
+    # _SpecialForm, but it's not clear how to get the two to match up
+    if typ.name == "Optional":
+        # convert from "Optional?" to the more familiar
+        # UnionType[..., NoneType()]
+        return unbound_to_instance(
+            api,
+            UnionType(
+                [unbound_to_instance(api, typ_arg) for typ_arg in typ.args]
+                + [NoneType()]
+            ),
+        )
+
+    node = api.lookup_qualified(typ.name, typ)
+
+    if (
+        node is not None
+        and isinstance(node, SymbolTableNode)
+        and isinstance(node.node, TypeInfo)
+    ):
+        bound_type = node.node
+
+        return Instance(
+            bound_type,
+            [
+                unbound_to_instance(api, arg)
+                if isinstance(arg, UnboundType)
+                else arg
+                for arg in typ.args
+            ],
+        )
+    else:
+        return typ
+
+
+def info_for_cls(
+    cls: ClassDef, api: SemanticAnalyzerPluginInterface
+) -> Optional[TypeInfo]:
+    if cls.info is CLASSDEF_NO_INFO:
+        sym = api.lookup_qualified(cls.name, cls)
+        if sym is None:
+            return None
+        assert sym and isinstance(sym.node, TypeInfo)
+        return sym.node
+
+    return cls.info
+
+
+def expr_to_mapped_constructor(expr: Expression) -> CallExpr:
+    column_descriptor = NameExpr("__sa_Mapped")
+    column_descriptor.fullname = "sqlalchemy.orm.attributes.Mapped"
+    member_expr = MemberExpr(column_descriptor, "_empty_constructor")
+    return CallExpr(
+        member_expr,
+        [expr],
+        [ARG_POS],
+        ["arg1"],
+    )
diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py
new file mode 100644
index 0000000..5a327d1
--- /dev/null
+++ b/lib/sqlalchemy/ext/orderinglist.py
@@ -0,0 +1,388 @@
+# ext/orderinglist.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""A custom list that manages index/position information for contained
+elements.
+
+:author: Jason Kirtland
+
+``orderinglist`` is a helper for mutable ordered relationships.  It will
+intercept list operations performed on a :func:`_orm.relationship`-managed
+collection and
+automatically synchronize changes in list position onto a target scalar
+attribute.
+
+Example: A ``slide`` table, where each row refers to zero or more entries
+in a related ``bullet`` table.  The bullets within a slide are
+displayed in order based on the value of the ``position`` column in the
+``bullet`` table.  As entries are reordered in memory, the value of the
+``position`` attribute should be updated to reflect the new sort order::
+
+    Base = declarative_base()
+
+    class Slide(Base):
+        __tablename__ = 'slide'
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+        bullets = relationship("Bullet", order_by="Bullet.position")
+
+    class Bullet(Base):
+        __tablename__ = 'bullet'
+        id = Column(Integer, primary_key=True)
+        slide_id = Column(Integer, ForeignKey('slide.id'))
+        position = Column(Integer)
+        text = Column(String)
+
+The standard relationship mapping will produce a list-like attribute on each
+``Slide`` containing all related ``Bullet`` objects, but it does not
+automatically cope with changes in ordering.  When a ``Bullet`` is appended
+to ``Slide.bullets``, the ``Bullet.position`` attribute will remain unset
+until manually assigned.  When a ``Bullet`` is inserted into the middle of
+the list, the following ``Bullet`` objects will also need to be renumbered.
+
+The :class:`.OrderingList` object automates this task, managing the
+``position`` attribute on all ``Bullet`` objects in the collection.  It is
+constructed using the :func:`.ordering_list` factory::
+
+    from sqlalchemy.ext.orderinglist import ordering_list
+
+    Base = declarative_base()
+
+    class Slide(Base):
+        __tablename__ = 'slide'
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+        bullets = relationship("Bullet", order_by="Bullet.position",
+                                collection_class=ordering_list('position'))
+
+    class Bullet(Base):
+        __tablename__ = 'bullet'
+        id = Column(Integer, primary_key=True)
+        slide_id = Column(Integer, ForeignKey('slide.id'))
+        position = Column(Integer)
+        text = Column(String)
+
+With the above mapping the ``Bullet.position`` attribute is managed::
+
+    >>> s = Slide()
+    >>> s.bullets.append(Bullet())
+    >>> s.bullets.append(Bullet())
+    >>> s.bullets[1].position
+    1
+    >>> s.bullets.insert(1, Bullet())
+    >>> s.bullets[2].position
+    2
+
+The :class:`.OrderingList` construct only works with **changes** to a
+collection, and not the initial load from the database, and requires that the
+list be sorted when loaded.  Therefore, be sure to specify ``order_by`` on the
+:func:`_orm.relationship` against the target ordering attribute, so that the
+ordering is correct when first loaded.
+
+.. warning::
+
+  :class:`.OrderingList` only provides limited functionality when a primary
+  key column or unique column is the target of the sort.  Operations
+  that are unsupported or are problematic include:
+
+    * two entries must trade values.  This is not supported directly in the
+      case of a primary key or unique constraint because it means at least
+      one row would need to be temporarily removed first, or changed to
+      a third, neutral value while the switch occurs.
+
+    * an entry must be deleted in order to make room for a new entry.
+      SQLAlchemy's unit of work performs all INSERTs before DELETEs within a
+      single flush. 
In the case of a primary key, it will trade + an INSERT/DELETE of the same primary key for an UPDATE statement in order + to lessen the impact of this limitation, however this does not take place + for a UNIQUE column. + A future feature will allow the "DELETE before INSERT" behavior to be + possible, alleviating this limitation, though this feature will require + explicit configuration at the mapper level for sets of columns that + are to be handled in this way. + +:func:`.ordering_list` takes the name of the related object's ordering +attribute as an argument. By default, the zero-based integer index of the +object's position in the :func:`.ordering_list` is synchronized with the +ordering attribute: index 0 will get position 0, index 1 position 1, etc. To +start numbering at 1 or some other integer, provide ``count_from=1``. + + +""" +from ..orm.collections import collection +from ..orm.collections import collection_adapter + + +__all__ = ["ordering_list"] + + +def ordering_list(attr, count_from=None, **kw): + """Prepares an :class:`OrderingList` factory for use in mapper definitions. + + Returns an object suitable for use as an argument to a Mapper + relationship's ``collection_class`` option. e.g.:: + + from sqlalchemy.ext.orderinglist import ordering_list + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) + + :param attr: + Name of the mapped attribute to use for storage and retrieval of + ordering information + + :param count_from: + Set up an integer-based ordering, starting at ``count_from``. For + example, ``ordering_list('pos', count_from=1)`` would create a 1-based + list in SQL, storing the value in the 'pos' column. Ignored if + ``ordering_func`` is supplied. + + Additional arguments are passed to the :class:`.OrderingList` constructor. + + """ + + kw = _unsugar_count_from(count_from=count_from, **kw) + return lambda: OrderingList(attr, **kw) + + +# Ordering utility functions + + +def count_from_0(index, collection): + """Numbering function: consecutive integers starting at 0.""" + + return index + + +def count_from_1(index, collection): + """Numbering function: consecutive integers starting at 1.""" + + return index + 1 + + +def count_from_n_factory(start): + """Numbering function: consecutive integers starting at arbitrary start.""" + + def f(index, collection): + return index + start + + try: + f.__name__ = "count_from_%i" % start + except TypeError: + pass + return f + + +def _unsugar_count_from(**kw): + """Builds counting functions from keyword arguments. + + Keyword argument filter, prepares a simple ``ordering_func`` from a + ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. + """ + + count_from = kw.pop("count_from", None) + if kw.get("ordering_func", None) is None and count_from is not None: + if count_from == 0: + kw["ordering_func"] = count_from_0 + elif count_from == 1: + kw["ordering_func"] = count_from_1 + else: + kw["ordering_func"] = count_from_n_factory(count_from) + return kw + + +class OrderingList(list): + """A custom list that manages position information for its children. + + The :class:`.OrderingList` object is normally set up using the + :func:`.ordering_list` factory function, used in conjunction with + the :func:`_orm.relationship` function. 
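+
+    A custom ``ordering_func`` may also be supplied to control the stored
+    values; e.g. a hypothetical stepped numbering function (illustrative
+    only; see also the pre-built counters in this module)::
+
+        def count_from_10_by_10(index, collection):
+            return (index + 1) * 10
+
+        bullets = relationship(
+            "Bullet",
+            order_by="Bullet.position",
+            collection_class=ordering_list(
+                'position', ordering_func=count_from_10_by_10))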
+
+    """
+
+    def __init__(
+        self, ordering_attr=None, ordering_func=None, reorder_on_append=False
+    ):
+        """A custom list that manages position information for its children.
+
+        ``OrderingList`` is a ``collection_class`` list implementation that
+        syncs position in a Python list with a position attribute on the
+        mapped objects.
+
+        This implementation relies on the list starting in the proper order,
+        so be **sure** to put an ``order_by`` on your relationship.
+
+        :param ordering_attr:
+          Name of the attribute that stores the object's order in the
+          relationship.
+
+        :param ordering_func: Optional.  A function that maps the position in
+          the Python list to a value to store in the
+          ``ordering_attr``.  Values returned are usually (but need not be!)
+          integers.
+
+          An ``ordering_func`` is called with two positional parameters: the
+          index of the element in the list, and the list itself.
+
+          If omitted, Python list indexes are used for the attribute values.
+          Two basic pre-built numbering functions are provided in this module:
+          ``count_from_0`` and ``count_from_1``.  For more exotic examples
+          like stepped numbering, alphabetical and Fibonacci numbering, see
+          the unit tests.
+
+        :param reorder_on_append:
+          Default False.  When appending an object with an existing (non-None)
+          ordering value, that value will be left untouched unless
+          ``reorder_on_append`` is true.  This is an optimization to avoid a
+          variety of dangerous unexpected database writes.
+
+          SQLAlchemy will add instances to the list via append() when your
+          object loads.  If for some reason the result set from the database
+          skips a step in the ordering (say, row '1' is missing but you get
+          '2', '3', and '4'), reorder_on_append=True would immediately
+          renumber the items to '1', '2', '3'.  If you have multiple sessions
+          making changes, any of whom happen to load this collection even in
+          passing, all of the sessions would try to "clean up" the numbering
+          in their commits, possibly causing all but one to fail with a
+          concurrent modification error.
+
+          It is recommended to leave this set to the default of False, and to
+          just call ``reorder()`` if you're doing ``append()`` operations with
+          previously ordered instances or when doing some housekeeping after
+          manual sql operations.
+
+        """
+        self.ordering_attr = ordering_attr
+        if ordering_func is None:
+            ordering_func = count_from_0
+        self.ordering_func = ordering_func
+        self.reorder_on_append = reorder_on_append
+
+    # More complex serialization schemes (multi column, e.g.) are possible by
+    # subclassing and reimplementing these two methods.
+    def _get_order_value(self, entity):
+        return getattr(entity, self.ordering_attr)
+
+    def _set_order_value(self, entity, value):
+        setattr(entity, self.ordering_attr, value)
+
+    def reorder(self):
+        """Synchronize ordering for the entire collection.
+
+        Sweeps through the list and ensures that each object has accurate
+        ordering information set. 
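+
+        For example, to renumber after out-of-band changes such as manual
+        SQL updates (an illustrative sketch, assuming the ``Slide`` /
+        ``Bullet`` mapping shown in the module docstring)::
+
+            slide.bullets.reorder()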
+ + """ + for index, entity in enumerate(self): + self._order_entity(index, entity, True) + + # As of 0.5, _reorder is no longer semi-private + _reorder = reorder + + def _order_entity(self, index, entity, reorder=True): + have = self._get_order_value(entity) + + # Don't disturb existing ordering if reorder is False + if have is not None and not reorder: + return + + should_be = self.ordering_func(index, self) + if have != should_be: + self._set_order_value(entity, should_be) + + def append(self, entity): + super(OrderingList, self).append(entity) + self._order_entity(len(self) - 1, entity, self.reorder_on_append) + + def _raw_append(self, entity): + """Append without any ordering behavior.""" + + super(OrderingList, self).append(entity) + + _raw_append = collection.adds(1)(_raw_append) + + def insert(self, index, entity): + super(OrderingList, self).insert(index, entity) + self._reorder() + + def remove(self, entity): + super(OrderingList, self).remove(entity) + + adapter = collection_adapter(self) + if adapter and adapter._referenced_by_owner: + self._reorder() + + def pop(self, index=-1): + entity = super(OrderingList, self).pop(index) + self._reorder() + return entity + + def __setitem__(self, index, entity): + if isinstance(index, slice): + step = index.step or 1 + start = index.start or 0 + if start < 0: + start += len(self) + stop = index.stop or len(self) + if stop < 0: + stop += len(self) + + for i in range(start, stop, step): + self.__setitem__(i, entity[i]) + else: + self._order_entity(index, entity, True) + super(OrderingList, self).__setitem__(index, entity) + + def __delitem__(self, index): + super(OrderingList, self).__delitem__(index) + self._reorder() + + def __setslice__(self, start, end, values): + super(OrderingList, self).__setslice__(start, end, values) + self._reorder() + + def __delslice__(self, start, end): + super(OrderingList, self).__delslice__(start, end) + self._reorder() + + def __reduce__(self): + return _reconstitute, (self.__class__, self.__dict__, list(self)) + + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + + +def _reconstitute(cls, dict_, items): + """Reconstitute an :class:`.OrderingList`. + + This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for + unpickling :class:`.OrderingList` objects. + + """ + obj = cls.__new__(cls) + obj.__dict__.update(dict_) + list.extend(obj, items) + return obj diff --git a/lib/sqlalchemy/ext/serializer.py b/lib/sqlalchemy/ext/serializer.py new file mode 100644 index 0000000..094b71b --- /dev/null +++ b/lib/sqlalchemy/ext/serializer.py @@ -0,0 +1,177 @@ +# ext/serializer.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Serializer/Deserializer objects for usage with SQLAlchemy query structures, +allowing "contextual" deserialization. + +Any SQLAlchemy query structure, either based on sqlalchemy.sql.* +or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session +etc. which are referenced by the structure are not persisted in serialized +form, but are instead re-associated with the query structure +when it is deserialized. 
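+
+(Internally, this is accomplished with the pickle module's
+``persistent_id`` and ``persistent_load`` hooks; the objects above are
+replaced by string tokens during serialization and are resolved against
+the given metadata / session at deserialization time.)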
+
+Usage is nearly the same as that of the standard Python pickle module::
+
+    from sqlalchemy.ext.serializer import loads, dumps
+    metadata = MetaData(bind=some_engine)
+    Session = scoped_session(sessionmaker())
+
+    # ... define mappers
+
+    query = Session.query(MyClass).filter(
+        MyClass.somedata == 'foo').order_by(MyClass.sortkey)
+
+    # pickle the query
+    serialized = dumps(query)
+
+    # unpickle.  Pass in metadata + scoped_session
+    query2 = loads(serialized, metadata, Session)
+
+    print(query2.all())
+
+Similar restrictions as when using raw pickle apply; mapped classes must
+themselves be pickleable, meaning they are importable from a module-level
+namespace.
+
+The serializer module is only appropriate for query structures.  It is not
+needed for:
+
+* instances of user-defined classes.  These contain no references to engines,
+  sessions or expression constructs in the typical case and can be serialized
+  directly.
+
+* Table metadata that is to be loaded entirely from the serialized structure
+  (i.e. is not already declared in the application).  Regular
+  pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object,
+  typically one which was reflected from an existing database at some previous
+  point in time.  The serializer module is specifically for the opposite case,
+  where the Table metadata is already present in memory.
+
+"""
+
+import re
+
+from .. import Column
+from .. import Table
+from ..engine import Engine
+from ..orm import class_mapper
+from ..orm.interfaces import MapperProperty
+from ..orm.mapper import Mapper
+from ..orm.session import Session
+from ..util import b64decode
+from ..util import b64encode
+from ..util import byte_buffer
+from ..util import pickle
+from ..util import text_type
+
+
+__all__ = ["Serializer", "Deserializer", "dumps", "loads"]
+
+
+def Serializer(*args, **kw):
+    pickler = pickle.Pickler(*args, **kw)
+
+    def persistent_id(obj):
+        # print("serializing:", repr(obj))
+        if isinstance(obj, Mapper) and not obj.non_primary:
+            id_ = "mapper:" + b64encode(pickle.dumps(obj.class_))
+        elif isinstance(obj, MapperProperty) and not obj.parent.non_primary:
+            id_ = (
+                "mapperprop:"
+                + b64encode(pickle.dumps(obj.parent.class_))
+                + ":"
+                + obj.key
+            )
+        elif isinstance(obj, Table):
+            if "parententity" in obj._annotations:
+                id_ = "mapper_selectable:" + b64encode(
+                    pickle.dumps(obj._annotations["parententity"].class_)
+                )
+            else:
+                id_ = "table:" + text_type(obj.key)
+        elif isinstance(obj, Column) and isinstance(obj.table, Table):
+            id_ = (
+                "column:" + text_type(obj.table.key) + ":" + text_type(obj.key)
+            )
+        elif isinstance(obj, Session):
+            id_ = "session:"
+        elif isinstance(obj, Engine):
+            id_ = "engine:"
+        else:
+            return None
+        return id_
+
+    pickler.persistent_id = persistent_id
+    return pickler
+
+
+our_ids = re.compile(
+    r"(mapperprop|mapper|mapper_selectable|table|column|"
+    r"session|attribute|engine):(.*)"
+)
+
+
+def Deserializer(file, metadata=None, scoped_session=None, engine=None):
+    unpickler = pickle.Unpickler(file)
+
+    def get_engine():
+        if engine:
+            return engine
+        elif scoped_session and scoped_session().bind:
+            return scoped_session().bind
+        elif metadata and metadata.bind:
+            return metadata.bind
+        else:
+            return None
+
+    def persistent_load(id_):
+        m = our_ids.match(text_type(id_))
+        if not m:
+            return None
+        else:
+            type_, args = m.group(1, 2)
+            if type_ == "attribute":
+                key, clsarg = args.split(":")
+                cls = pickle.loads(b64decode(clsarg))
+                return getattr(cls, key)
+            elif type_ == "mapper":
+                cls = pickle.loads(b64decode(args))
+                return class_mapper(cls)
+            elif type_ == "mapper_selectable":
+                cls = pickle.loads(b64decode(args))
+                return class_mapper(cls).__clause_element__()
+            elif type_ == "mapperprop":
+                mapper, keyname = args.split(":")
+                cls = pickle.loads(b64decode(mapper))
+                return class_mapper(cls).attrs[keyname]
+            elif type_ == "table":
+                return metadata.tables[args]
+            elif type_ == "column":
+                table, colname = args.split(":")
+                return metadata.tables[table].c[colname]
+            elif type_ == "session":
+                return scoped_session()
+            elif type_ == "engine":
+                return get_engine()
+            else:
+                raise Exception("Unknown token: %s" % type_)
+
+    unpickler.persistent_load = persistent_load
+    return unpickler
+
+
+def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL):
+    buf = byte_buffer()
+    pickler = Serializer(buf, protocol)
+    pickler.dump(obj)
+    return buf.getvalue()
+
+
+def loads(data, metadata=None, scoped_session=None, engine=None):
+    buf = byte_buffer(data)
+    unpickler = Deserializer(buf, metadata, scoped_session, engine)
+    return unpickler.load()
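+
+
+# A minimal usage sketch (editor's illustration, not part of the module):
+# the file-based Serializer / Deserializer factories that ``dumps`` and
+# ``loads`` wrap can be used directly with any file-like object.
+# ``query``, ``metadata`` and ``Session`` are assumed to be set up as in
+# the module docstring above.
+#
+#     import pickle
+#     from io import BytesIO
+#     from sqlalchemy.ext.serializer import Serializer, Deserializer
+#
+#     buf = BytesIO()
+#     Serializer(buf, pickle.HIGHEST_PROTOCOL).dump(query)
+#     buf.seek(0)
+#     query2 = Deserializer(buf, metadata, Session).load()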