diff options
Diffstat (limited to 'lib/sqlalchemy/util')
-rw-r--r-- | lib/sqlalchemy/util/__init__.py | 175 | ||||
-rw-r--r-- | lib/sqlalchemy/util/_collections.py | 1089 | ||||
-rw-r--r-- | lib/sqlalchemy/util/_compat_py3k.py | 67 | ||||
-rw-r--r-- | lib/sqlalchemy/util/_concurrency_py3k.py | 194 | ||||
-rw-r--r-- | lib/sqlalchemy/util/_preloaded.py | 68 | ||||
-rw-r--r-- | lib/sqlalchemy/util/compat.py | 632 | ||||
-rw-r--r-- | lib/sqlalchemy/util/concurrency.py | 73 | ||||
-rw-r--r-- | lib/sqlalchemy/util/deprecations.py | 417 | ||||
-rw-r--r-- | lib/sqlalchemy/util/langhelpers.py | 1945 | ||||
-rw-r--r-- | lib/sqlalchemy/util/queue.py | 291 | ||||
-rw-r--r-- | lib/sqlalchemy/util/topological.py | 100 |
11 files changed, 5051 insertions, 0 deletions
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py new file mode 100644 index 0000000..33427e3 --- /dev/null +++ b/lib/sqlalchemy/util/__init__.py @@ -0,0 +1,175 @@ +# util/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +from collections import defaultdict +from contextlib import contextmanager +from functools import partial +from functools import update_wrapper + +from ._collections import coerce_generator_arg +from ._collections import coerce_to_immutabledict +from ._collections import collections_abc +from ._collections import column_dict +from ._collections import column_set +from ._collections import EMPTY_DICT +from ._collections import EMPTY_SET +from ._collections import FacadeDict +from ._collections import flatten_iterator +from ._collections import has_dupes +from ._collections import has_intersection +from ._collections import IdentitySet +from ._collections import ImmutableContainer +from ._collections import immutabledict +from ._collections import ImmutableProperties +from ._collections import LRUCache +from ._collections import ordered_column_set +from ._collections import OrderedDict +from ._collections import OrderedIdentitySet +from ._collections import OrderedProperties +from ._collections import OrderedSet +from ._collections import PopulateDict +from ._collections import Properties +from ._collections import ScopedRegistry +from ._collections import sort_dictionary +from ._collections import ThreadLocalRegistry +from ._collections import to_column_set +from ._collections import to_list +from ._collections import to_set +from ._collections import unique_list +from ._collections import UniqueAppender +from ._collections import update_copy +from ._collections import WeakPopulateDict +from ._collections import WeakSequence +from ._preloaded import preload_module +from ._preloaded import preloaded +from .compat import ABC +from .compat import arm +from .compat import b +from .compat import b64decode +from .compat import b64encode +from .compat import binary_type +from .compat import binary_types +from .compat import byte_buffer +from .compat import callable +from .compat import cmp +from .compat import cpython +from .compat import dataclass_fields +from .compat import decode_backslashreplace +from .compat import dottedgetter +from .compat import has_refcount_gc +from .compat import inspect_getfullargspec +from .compat import int_types +from .compat import iterbytes +from .compat import itertools_filter +from .compat import itertools_filterfalse +from .compat import local_dataclass_fields +from .compat import namedtuple +from .compat import next +from .compat import nullcontext +from .compat import osx +from .compat import parse_qsl +from .compat import perf_counter +from .compat import pickle +from .compat import print_ +from .compat import py2k +from .compat import py311 +from .compat import py37 +from .compat import py38 +from .compat import py39 +from .compat import py3k +from .compat import pypy +from .compat import quote_plus +from .compat import raise_ +from .compat import raise_from_cause +from .compat import reduce +from .compat import reraise +from .compat import string_types +from .compat import StringIO +from .compat import text_type +from .compat import threading +from .compat import timezone +from .compat import TYPE_CHECKING +from .compat import u +from .compat import ue +from .compat import unquote +from .compat import unquote_plus +from .compat import win32 +from .compat import with_metaclass +from .compat import zip_longest +from .concurrency import asyncio +from .concurrency import await_fallback +from .concurrency import await_only +from .concurrency import greenlet_spawn +from .concurrency import is_exit_exception +from .deprecations import deprecated +from .deprecations import deprecated_20 +from .deprecations import deprecated_20_cls +from .deprecations import deprecated_cls +from .deprecations import deprecated_params +from .deprecations import inject_docstring_text +from .deprecations import moved_20 +from .deprecations import SQLALCHEMY_WARN_20 +from .deprecations import warn_deprecated +from .deprecations import warn_deprecated_20 +from .langhelpers import add_parameter_text +from .langhelpers import as_interface +from .langhelpers import asbool +from .langhelpers import asint +from .langhelpers import assert_arg_type +from .langhelpers import attrsetter +from .langhelpers import bool_or_str +from .langhelpers import chop_traceback +from .langhelpers import class_hierarchy +from .langhelpers import classproperty +from .langhelpers import clsname_as_plain_name +from .langhelpers import coerce_kw_type +from .langhelpers import constructor_copy +from .langhelpers import constructor_key +from .langhelpers import counter +from .langhelpers import create_proxy_methods +from .langhelpers import decode_slice +from .langhelpers import decorator +from .langhelpers import dictlike_iteritems +from .langhelpers import duck_type_collection +from .langhelpers import ellipses_string +from .langhelpers import EnsureKWArgType +from .langhelpers import format_argspec_init +from .langhelpers import format_argspec_plus +from .langhelpers import generic_repr +from .langhelpers import get_callable_argspec +from .langhelpers import get_cls_kwargs +from .langhelpers import get_func_kwargs +from .langhelpers import getargspec_init +from .langhelpers import has_compiled_ext +from .langhelpers import HasMemoized +from .langhelpers import hybridmethod +from .langhelpers import hybridproperty +from .langhelpers import iterate_attributes +from .langhelpers import map_bits +from .langhelpers import md5_hex +from .langhelpers import memoized_instancemethod +from .langhelpers import memoized_property +from .langhelpers import MemoizedSlots +from .langhelpers import method_is_overridden +from .langhelpers import methods_equivalent +from .langhelpers import monkeypatch_proxied_specials +from .langhelpers import NoneType +from .langhelpers import only_once +from .langhelpers import PluginLoader +from .langhelpers import portable_instancemethod +from .langhelpers import quoted_token_parser +from .langhelpers import safe_reraise +from .langhelpers import set_creation_order +from .langhelpers import string_or_unprintable +from .langhelpers import symbol +from .langhelpers import unbound_method_to_callable +from .langhelpers import walk_subclasses +from .langhelpers import warn +from .langhelpers import warn_exception +from .langhelpers import warn_limited +from .langhelpers import wrap_callable diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py new file mode 100644 index 0000000..8e21830 --- /dev/null +++ b/lib/sqlalchemy/util/_collections.py @@ -0,0 +1,1089 @@ +# util/_collections.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Collection classes and helpers.""" + +from __future__ import absolute_import + +import operator +import types +import weakref + +from .compat import binary_types +from .compat import collections_abc +from .compat import itertools_filterfalse +from .compat import py2k +from .compat import py37 +from .compat import string_types +from .compat import threading + + +EMPTY_SET = frozenset() + + +class ImmutableContainer(object): + def _immutable(self, *arg, **kw): + raise TypeError("%s object is immutable" % self.__class__.__name__) + + __delitem__ = __setitem__ = __setattr__ = _immutable + + +def _immutabledict_py_fallback(): + class immutabledict(ImmutableContainer, dict): + + clear = ( + pop + ) = popitem = setdefault = update = ImmutableContainer._immutable + + def __new__(cls, *args): + new = dict.__new__(cls) + dict.__init__(new, *args) + return new + + def __init__(self, *args): + pass + + def __reduce__(self): + return _immutabledict_reconstructor, (dict(self),) + + def union(self, __d=None): + if not __d: + return self + + new = dict.__new__(self.__class__) + dict.__init__(new, self) + dict.update(new, __d) + return new + + def _union_w_kw(self, __d=None, **kw): + # not sure if C version works correctly w/ this yet + if not __d and not kw: + return self + + new = dict.__new__(self.__class__) + dict.__init__(new, self) + if __d: + dict.update(new, __d) + dict.update(new, kw) + return new + + def merge_with(self, *dicts): + new = None + for d in dicts: + if d: + if new is None: + new = dict.__new__(self.__class__) + dict.__init__(new, self) + dict.update(new, d) + if new is None: + return self + + return new + + def __repr__(self): + return "immutabledict(%s)" % dict.__repr__(self) + + return immutabledict + + +try: + from sqlalchemy.cimmutabledict import immutabledict + + collections_abc.Mapping.register(immutabledict) + +except ImportError: + immutabledict = _immutabledict_py_fallback() + + def _immutabledict_reconstructor(*arg): + """do the pickle dance""" + return immutabledict(*arg) + + +def coerce_to_immutabledict(d): + if not d: + return EMPTY_DICT + elif isinstance(d, immutabledict): + return d + else: + return immutabledict(d) + + +EMPTY_DICT = immutabledict() + + +class FacadeDict(ImmutableContainer, dict): + """A dictionary that is not publicly mutable.""" + + clear = pop = popitem = setdefault = update = ImmutableContainer._immutable + + def __new__(cls, *args): + new = dict.__new__(cls) + return new + + def copy(self): + raise NotImplementedError( + "an immutabledict shouldn't need to be copied. use dict(d) " + "if you need a mutable dictionary." + ) + + def __reduce__(self): + return FacadeDict, (dict(self),) + + def _insert_item(self, key, value): + """insert an item into the dictionary directly.""" + dict.__setitem__(self, key, value) + + def __repr__(self): + return "FacadeDict(%s)" % dict.__repr__(self) + + +class Properties(object): + """Provide a __getattr__/__setattr__ interface over a dict.""" + + __slots__ = ("_data",) + + def __init__(self, data): + object.__setattr__(self, "_data", data) + + def __len__(self): + return len(self._data) + + def __iter__(self): + return iter(list(self._data.values())) + + def __dir__(self): + return dir(super(Properties, self)) + [ + str(k) for k in self._data.keys() + ] + + def __add__(self, other): + return list(self) + list(other) + + def __setitem__(self, key, obj): + self._data[key] = obj + + def __getitem__(self, key): + return self._data[key] + + def __delitem__(self, key): + del self._data[key] + + def __setattr__(self, key, obj): + self._data[key] = obj + + def __getstate__(self): + return {"_data": self._data} + + def __setstate__(self, state): + object.__setattr__(self, "_data", state["_data"]) + + def __getattr__(self, key): + try: + return self._data[key] + except KeyError: + raise AttributeError(key) + + def __contains__(self, key): + return key in self._data + + def as_immutable(self): + """Return an immutable proxy for this :class:`.Properties`.""" + + return ImmutableProperties(self._data) + + def update(self, value): + self._data.update(value) + + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def keys(self): + return list(self._data) + + def values(self): + return list(self._data.values()) + + def items(self): + return list(self._data.items()) + + def has_key(self, key): + return key in self._data + + def clear(self): + self._data.clear() + + +class OrderedProperties(Properties): + """Provide a __getattr__/__setattr__ interface with an OrderedDict + as backing store.""" + + __slots__ = () + + def __init__(self): + Properties.__init__(self, OrderedDict()) + + +class ImmutableProperties(ImmutableContainer, Properties): + """Provide immutable dict/object attribute to an underlying dictionary.""" + + __slots__ = () + + +def _ordered_dictionary_sort(d, key=None): + """Sort an OrderedDict in-place.""" + + items = [(k, d[k]) for k in sorted(d, key=key)] + + d.clear() + + d.update(items) + + +if py37: + OrderedDict = dict + sort_dictionary = _ordered_dictionary_sort + +else: + # prevent sort_dictionary from being used against a plain dictionary + # for Python < 3.7 + + def sort_dictionary(d, key=None): + """Sort an OrderedDict in place.""" + + d._ordered_dictionary_sort(key=key) + + class OrderedDict(dict): + """Dictionary that maintains insertion order. + + Superseded by Python dict as of Python 3.7 + + """ + + __slots__ = ("_list",) + + def _ordered_dictionary_sort(self, key=None): + _ordered_dictionary_sort(self, key=key) + + def __reduce__(self): + return OrderedDict, (self.items(),) + + def __init__(self, ____sequence=None, **kwargs): + self._list = [] + if ____sequence is None: + if kwargs: + self.update(**kwargs) + else: + self.update(____sequence, **kwargs) + + def clear(self): + self._list = [] + dict.clear(self) + + def copy(self): + return self.__copy__() + + def __copy__(self): + return OrderedDict(self) + + def update(self, ____sequence=None, **kwargs): + if ____sequence is not None: + if hasattr(____sequence, "keys"): + for key in ____sequence.keys(): + self.__setitem__(key, ____sequence[key]) + else: + for key, value in ____sequence: + self[key] = value + if kwargs: + self.update(kwargs) + + def setdefault(self, key, value): + if key not in self: + self.__setitem__(key, value) + return value + else: + return self.__getitem__(key) + + def __iter__(self): + return iter(self._list) + + def keys(self): + return list(self) + + def values(self): + return [self[key] for key in self._list] + + def items(self): + return [(key, self[key]) for key in self._list] + + if py2k: + + def itervalues(self): + return iter(self.values()) + + def iterkeys(self): + return iter(self) + + def iteritems(self): + return iter(self.items()) + + def __setitem__(self, key, obj): + if key not in self: + try: + self._list.append(key) + except AttributeError: + # work around Python pickle loads() with + # dict subclass (seems to ignore __setstate__?) + self._list = [key] + dict.__setitem__(self, key, obj) + + def __delitem__(self, key): + dict.__delitem__(self, key) + self._list.remove(key) + + def pop(self, key, *default): + present = key in self + value = dict.pop(self, key, *default) + if present: + self._list.remove(key) + return value + + def popitem(self): + item = dict.popitem(self) + self._list.remove(item[0]) + return item + + +class OrderedSet(set): + def __init__(self, d=None): + set.__init__(self) + if d is not None: + self._list = unique_list(d) + set.update(self, self._list) + else: + self._list = [] + + def add(self, element): + if element not in self: + self._list.append(element) + set.add(self, element) + + def remove(self, element): + set.remove(self, element) + self._list.remove(element) + + def insert(self, pos, element): + if element not in self: + self._list.insert(pos, element) + set.add(self, element) + + def discard(self, element): + if element in self: + self._list.remove(element) + set.remove(self, element) + + def clear(self): + set.clear(self) + self._list = [] + + def __getitem__(self, key): + return self._list[key] + + def __iter__(self): + return iter(self._list) + + def __add__(self, other): + return self.union(other) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._list) + + __str__ = __repr__ + + def update(self, iterable): + for e in iterable: + if e not in self: + self._list.append(e) + set.add(self, e) + return self + + __ior__ = update + + def union(self, other): + result = self.__class__(self) + result.update(other) + return result + + __or__ = union + + def intersection(self, other): + other = set(other) + return self.__class__(a for a in self if a in other) + + __and__ = intersection + + def symmetric_difference(self, other): + other = set(other) + result = self.__class__(a for a in self if a not in other) + result.update(a for a in other if a not in self) + return result + + __xor__ = symmetric_difference + + def difference(self, other): + other = set(other) + return self.__class__(a for a in self if a not in other) + + __sub__ = difference + + def intersection_update(self, other): + other = set(other) + set.intersection_update(self, other) + self._list = [a for a in self._list if a in other] + return self + + __iand__ = intersection_update + + def symmetric_difference_update(self, other): + set.symmetric_difference_update(self, other) + self._list = [a for a in self._list if a in self] + self._list += [a for a in other._list if a in self] + return self + + __ixor__ = symmetric_difference_update + + def difference_update(self, other): + set.difference_update(self, other) + self._list = [a for a in self._list if a in self] + return self + + __isub__ = difference_update + + +class IdentitySet(object): + """A set that considers only object id() for uniqueness. + + This strategy has edge cases for builtin types- it's possible to have + two 'foo' strings in one of these sets, for example. Use sparingly. + + """ + + def __init__(self, iterable=None): + self._members = dict() + if iterable: + self.update(iterable) + + def add(self, value): + self._members[id(value)] = value + + def __contains__(self, value): + return id(value) in self._members + + def remove(self, value): + del self._members[id(value)] + + def discard(self, value): + try: + self.remove(value) + except KeyError: + pass + + def pop(self): + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError("pop from an empty set") + + def clear(self): + self._members.clear() + + def __cmp__(self, other): + raise TypeError("cannot compare sets using cmp()") + + def __eq__(self, other): + if isinstance(other, IdentitySet): + return self._members == other._members + else: + return False + + def __ne__(self, other): + if isinstance(other, IdentitySet): + return self._members != other._members + else: + return True + + def issubset(self, iterable): + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) > len(other): + return False + for m in itertools_filterfalse( + other._members.__contains__, iter(self._members.keys()) + ): + return False + return True + + def __le__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) < len(other) and self.issubset(other) + + def issuperset(self, iterable): + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) < len(other): + return False + + for m in itertools_filterfalse( + self._members.__contains__, iter(other._members.keys()) + ): + return False + return True + + def __ge__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) + + def union(self, iterable): + result = self.__class__() + members = self._members + result._members.update(members) + result._members.update((id(obj), obj) for obj in iterable) + return result + + def __or__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.union(other) + + def update(self, iterable): + self._members.update((id(obj), obj) for obj in iterable) + + def __ior__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.update(other) + return self + + def difference(self, iterable): + result = self.__class__() + members = self._members + if isinstance(iterable, self.__class__): + other = set(iterable._members.keys()) + else: + other = {id(obj) for obj in iterable} + result._members.update( + ((k, v) for k, v in members.items() if k not in other) + ) + return result + + def __sub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.difference(other) + + def difference_update(self, iterable): + self._members = self.difference(iterable)._members + + def __isub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.difference_update(other) + return self + + def intersection(self, iterable): + result = self.__class__() + members = self._members + if isinstance(iterable, self.__class__): + other = set(iterable._members.keys()) + else: + other = {id(obj) for obj in iterable} + result._members.update( + (k, v) for k, v in members.items() if k in other + ) + return result + + def __and__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.intersection(other) + + def intersection_update(self, iterable): + self._members = self.intersection(iterable)._members + + def __iand__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.intersection_update(other) + return self + + def symmetric_difference(self, iterable): + result = self.__class__() + members = self._members + if isinstance(iterable, self.__class__): + other = iterable._members + else: + other = {id(obj): obj for obj in iterable} + result._members.update( + ((k, v) for k, v in members.items() if k not in other) + ) + result._members.update( + ((k, v) for k, v in other.items() if k not in members) + ) + return result + + def __xor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.symmetric_difference(other) + + def symmetric_difference_update(self, iterable): + self._members = self.symmetric_difference(iterable)._members + + def __ixor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.symmetric_difference(other) + return self + + def copy(self): + return type(self)(iter(self._members.values())) + + __copy__ = copy + + def __len__(self): + return len(self._members) + + def __iter__(self): + return iter(self._members.values()) + + def __hash__(self): + raise TypeError("set objects are unhashable") + + def __repr__(self): + return "%s(%r)" % (type(self).__name__, list(self._members.values())) + + +class WeakSequence(object): + def __init__(self, __elements=()): + # adapted from weakref.WeakKeyDictionary, prevent reference + # cycles in the collection itself + def _remove(item, selfref=weakref.ref(self)): + self = selfref() + if self is not None: + self._storage.remove(item) + + self._remove = _remove + self._storage = [ + weakref.ref(element, _remove) for element in __elements + ] + + def append(self, item): + self._storage.append(weakref.ref(item, self._remove)) + + def __len__(self): + return len(self._storage) + + def __iter__(self): + return ( + obj for obj in (ref() for ref in self._storage) if obj is not None + ) + + def __getitem__(self, index): + try: + obj = self._storage[index] + except KeyError: + raise IndexError("Index %s out of range" % index) + else: + return obj() + + +class OrderedIdentitySet(IdentitySet): + def __init__(self, iterable=None): + IdentitySet.__init__(self) + self._members = OrderedDict() + if iterable: + for o in iterable: + self.add(o) + + +class PopulateDict(dict): + """A dict which populates missing values via a creation function. + + Note the creation function takes a key, unlike + collections.defaultdict. + + """ + + def __init__(self, creator): + self.creator = creator + + def __missing__(self, key): + self[key] = val = self.creator(key) + return val + + +class WeakPopulateDict(dict): + """Like PopulateDict, but assumes a self + a method and does not create + a reference cycle. + + """ + + def __init__(self, creator_method): + self.creator = creator_method.__func__ + weakself = creator_method.__self__ + self.weakself = weakref.ref(weakself) + + def __missing__(self, key): + self[key] = val = self.creator(self.weakself(), key) + return val + + +# Define collections that are capable of storing +# ColumnElement objects as hashable keys/elements. +# At this point, these are mostly historical, things +# used to be more complicated. +column_set = set +column_dict = dict +ordered_column_set = OrderedSet + + +_getters = PopulateDict(operator.itemgetter) + +_property_getters = PopulateDict( + lambda idx: property(operator.itemgetter(idx)) +) + + +def unique_list(seq, hashfunc=None): + seen = set() + seen_add = seen.add + if not hashfunc: + return [x for x in seq if x not in seen and not seen_add(x)] + else: + return [ + x + for x in seq + if hashfunc(x) not in seen and not seen_add(hashfunc(x)) + ] + + +class UniqueAppender(object): + """Appends items to a collection ensuring uniqueness. + + Additional appends() of the same object are ignored. Membership is + determined by identity (``is a``) not equality (``==``). + """ + + def __init__(self, data, via=None): + self.data = data + self._unique = {} + if via: + self._data_appender = getattr(data, via) + elif hasattr(data, "append"): + self._data_appender = data.append + elif hasattr(data, "add"): + self._data_appender = data.add + + def append(self, item): + id_ = id(item) + if id_ not in self._unique: + self._data_appender(item) + self._unique[id_] = True + + def __iter__(self): + return iter(self.data) + + +def coerce_generator_arg(arg): + if len(arg) == 1 and isinstance(arg[0], types.GeneratorType): + return list(arg[0]) + else: + return arg + + +def to_list(x, default=None): + if x is None: + return default + if not isinstance(x, collections_abc.Iterable) or isinstance( + x, string_types + binary_types + ): + return [x] + elif isinstance(x, list): + return x + else: + return list(x) + + +def has_intersection(set_, iterable): + r"""return True if any items of set\_ are present in iterable. + + Goes through special effort to ensure __hash__ is not called + on items in iterable that don't support it. + + """ + # TODO: optimize, write in C, etc. + return bool(set_.intersection([i for i in iterable if i.__hash__])) + + +def to_set(x): + if x is None: + return set() + if not isinstance(x, set): + return set(to_list(x)) + else: + return x + + +def to_column_set(x): + if x is None: + return column_set() + if not isinstance(x, column_set): + return column_set(to_list(x)) + else: + return x + + +def update_copy(d, _new=None, **kw): + """Copy the given dict and update with the given values.""" + + d = d.copy() + if _new: + d.update(_new) + d.update(**kw) + return d + + +def flatten_iterator(x): + """Given an iterator of which further sub-elements may also be + iterators, flatten the sub-elements into a single iterator. + + """ + for elem in x: + if not isinstance(elem, str) and hasattr(elem, "__iter__"): + for y in flatten_iterator(elem): + yield y + else: + yield elem + + +class LRUCache(dict): + """Dictionary with 'squishy' removal of least + recently used items. + + Note that either get() or [] should be used here, but + generally its not safe to do an "in" check first as the dictionary + can change subsequent to that call. + + """ + + __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex" + + def __init__(self, capacity=100, threshold=0.5, size_alert=None): + self.capacity = capacity + self.threshold = threshold + self.size_alert = size_alert + self._counter = 0 + self._mutex = threading.Lock() + + def _inc_counter(self): + self._counter += 1 + return self._counter + + def get(self, key, default=None): + item = dict.get(self, key, default) + if item is not default: + item[2] = self._inc_counter() + return item[1] + else: + return default + + def __getitem__(self, key): + item = dict.__getitem__(self, key) + item[2] = self._inc_counter() + return item[1] + + def values(self): + return [i[1] for i in dict.values(self)] + + def setdefault(self, key, value): + if key in self: + return self[key] + else: + self[key] = value + return value + + def __setitem__(self, key, value): + item = dict.get(self, key) + if item is None: + item = [key, value, self._inc_counter()] + dict.__setitem__(self, key, item) + else: + item[1] = value + self._manage_size() + + @property + def size_threshold(self): + return self.capacity + self.capacity * self.threshold + + def _manage_size(self): + if not self._mutex.acquire(False): + return + try: + size_alert = bool(self.size_alert) + while len(self) > self.capacity + self.capacity * self.threshold: + if size_alert: + size_alert = False + self.size_alert(self) + by_counter = sorted( + dict.values(self), key=operator.itemgetter(2), reverse=True + ) + for item in by_counter[self.capacity :]: + try: + del self[item[0]] + except KeyError: + # deleted elsewhere; skip + continue + finally: + self._mutex.release() + + +class ScopedRegistry(object): + """A Registry that can store one or multiple instances of a single + class on the basis of a "scope" function. + + The object implements ``__call__`` as the "getter", so by + calling ``myregistry()`` the contained object is returned + for the current scope. + + :param createfunc: + a callable that returns a new object to be placed in the registry + + :param scopefunc: + a callable that will return a key to store/retrieve an object. + """ + + def __init__(self, createfunc, scopefunc): + """Construct a new :class:`.ScopedRegistry`. + + :param createfunc: A creation function that will generate + a new value for the current scope, if none is present. + + :param scopefunc: A function that returns a hashable + token representing the current scope (such as, current + thread identifier). + + """ + self.createfunc = createfunc + self.scopefunc = scopefunc + self.registry = {} + + def __call__(self): + key = self.scopefunc() + try: + return self.registry[key] + except KeyError: + return self.registry.setdefault(key, self.createfunc()) + + def has(self): + """Return True if an object is present in the current scope.""" + + return self.scopefunc() in self.registry + + def set(self, obj): + """Set the value for the current scope.""" + + self.registry[self.scopefunc()] = obj + + def clear(self): + """Clear the current scope, if any.""" + + try: + del self.registry[self.scopefunc()] + except KeyError: + pass + + +class ThreadLocalRegistry(ScopedRegistry): + """A :class:`.ScopedRegistry` that uses a ``threading.local()`` + variable for storage. + + """ + + def __init__(self, createfunc): + self.createfunc = createfunc + self.registry = threading.local() + + def __call__(self): + try: + return self.registry.value + except AttributeError: + val = self.registry.value = self.createfunc() + return val + + def has(self): + return hasattr(self.registry, "value") + + def set(self, obj): + self.registry.value = obj + + def clear(self): + try: + del self.registry.value + except AttributeError: + pass + + +def has_dupes(sequence, target): + """Given a sequence and search object, return True if there's more + than one, False if zero or one of them. + + + """ + # compare to .index version below, this version introduces less function + # overhead and is usually the same speed. At 15000 items (way bigger than + # a relationship-bound collection in memory usually is) it begins to + # fall behind the other version only by microseconds. + c = 0 + for item in sequence: + if item is target: + c += 1 + if c > 1: + return True + return False + + +# .index version. the two __contains__ calls as well +# as .index() and isinstance() slow this down. +# def has_dupes(sequence, target): +# if target not in sequence: +# return False +# elif not isinstance(sequence, collections_abc.Sequence): +# return False +# +# idx = sequence.index(target) +# return target in sequence[idx + 1:] diff --git a/lib/sqlalchemy/util/_compat_py3k.py b/lib/sqlalchemy/util/_compat_py3k.py new file mode 100644 index 0000000..ce659a4 --- /dev/null +++ b/lib/sqlalchemy/util/_compat_py3k.py @@ -0,0 +1,67 @@ +# util/_compat_py3k.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from functools import wraps + +# vendored from py3.7 + + +class _AsyncGeneratorContextManager: + """Helper for @asynccontextmanager.""" + + def __init__(self, func, args, kwds): + self.gen = func(*args, **kwds) + self.func, self.args, self.kwds = func, args, kwds + doc = getattr(func, "__doc__", None) + if doc is None: + doc = type(self).__doc__ + self.__doc__ = doc + + async def __aenter__(self): + try: + return await self.gen.__anext__() + except StopAsyncIteration: + raise RuntimeError("generator didn't yield") from None + + async def __aexit__(self, typ, value, traceback): + if typ is None: + try: + await self.gen.__anext__() + except StopAsyncIteration: + return + else: + raise RuntimeError("generator didn't stop") + else: + if value is None: + value = typ() + # See _GeneratorContextManager.__exit__ for comments on subtleties + # in this implementation + try: + await self.gen.athrow(typ, value, traceback) + raise RuntimeError("generator didn't stop after athrow()") + except StopAsyncIteration as exc: + return exc is not value + except RuntimeError as exc: + if exc is value: + return False + if isinstance(value, (StopIteration, StopAsyncIteration)): + if exc.__cause__ is value: + return False + raise + except BaseException as exc: + if exc is not value: + raise + + +# using the vendored version in all cases at the moment to establish +# full test coverage +def asynccontextmanager(func): + @wraps(func) + def helper(*args, **kwds): + return _AsyncGeneratorContextManager(func, args, kwds) + + return helper diff --git a/lib/sqlalchemy/util/_concurrency_py3k.py b/lib/sqlalchemy/util/_concurrency_py3k.py new file mode 100644 index 0000000..0b12834 --- /dev/null +++ b/lib/sqlalchemy/util/_concurrency_py3k.py @@ -0,0 +1,194 @@ +# util/_concurrency_py3k.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import asyncio +import sys +from typing import Any +from typing import Callable +from typing import Coroutine + +import greenlet + +from . import compat +from .langhelpers import memoized_property +from .. import exc + +# If greenlet.gr_context is present in current version of greenlet, +# it will be set with the current context on creation. +# Refs: https://github.com/python-greenlet/greenlet/pull/198 +_has_gr_context = hasattr(greenlet.getcurrent(), "gr_context") + + +def is_exit_exception(e): + # note asyncio.CancelledError is already BaseException + # so was an exit exception in any case + return not isinstance(e, Exception) or isinstance( + e, (asyncio.TimeoutError, asyncio.CancelledError) + ) + + +# implementation based on snaury gist at +# https://gist.github.com/snaury/202bf4f22c41ca34e56297bae5f33fef +# Issue for context: https://github.com/python-greenlet/greenlet/issues/173 + + +class _AsyncIoGreenlet(greenlet.greenlet): + def __init__(self, fn, driver): + greenlet.greenlet.__init__(self, fn, driver) + self.driver = driver + if _has_gr_context: + self.gr_context = driver.gr_context + + +def await_only(awaitable: Coroutine) -> Any: + """Awaits an async function in a sync method. + + The sync method must be inside a :func:`greenlet_spawn` context. + :func:`await_only` calls cannot be nested. + + :param awaitable: The coroutine to call. + + """ + # this is called in the context greenlet while running fn + current = greenlet.getcurrent() + if not isinstance(current, _AsyncIoGreenlet): + raise exc.MissingGreenlet( + "greenlet_spawn has not been called; can't call await_only() " + "here. Was IO attempted in an unexpected place?" + ) + + # returns the control to the driver greenlet passing it + # a coroutine to run. Once the awaitable is done, the driver greenlet + # switches back to this greenlet with the result of awaitable that is + # then returned to the caller (or raised as error) + return current.driver.switch(awaitable) + + +def await_fallback(awaitable: Coroutine) -> Any: + """Awaits an async function in a sync method. + + The sync method must be inside a :func:`greenlet_spawn` context. + :func:`await_fallback` calls cannot be nested. + + :param awaitable: The coroutine to call. + + """ + # this is called in the context greenlet while running fn + current = greenlet.getcurrent() + if not isinstance(current, _AsyncIoGreenlet): + loop = get_event_loop() + if loop.is_running(): + raise exc.MissingGreenlet( + "greenlet_spawn has not been called and asyncio event " + "loop is already running; can't call await_fallback() here. " + "Was IO attempted in an unexpected place?" + ) + return loop.run_until_complete(awaitable) + + return current.driver.switch(awaitable) + + +async def greenlet_spawn( + fn: Callable, *args, _require_await=False, **kwargs +) -> Any: + """Runs a sync function ``fn`` in a new greenlet. + + The sync function can then use :func:`await_only` to wait for async + functions. + + :param fn: The sync callable to call. + :param \\*args: Positional arguments to pass to the ``fn`` callable. + :param \\*\\*kwargs: Keyword arguments to pass to the ``fn`` callable. + """ + + context = _AsyncIoGreenlet(fn, greenlet.getcurrent()) + # runs the function synchronously in gl greenlet. If the execution + # is interrupted by await_only, context is not dead and result is a + # coroutine to wait. If the context is dead the function has + # returned, and its result can be returned. + switch_occurred = False + try: + result = context.switch(*args, **kwargs) + while not context.dead: + switch_occurred = True + try: + # wait for a coroutine from await_only and then return its + # result back to it. + value = await result + except BaseException: + # this allows an exception to be raised within + # the moderated greenlet so that it can continue + # its expected flow. + result = context.throw(*sys.exc_info()) + else: + result = context.switch(value) + finally: + # clean up to avoid cycle resolution by gc + del context.driver + if _require_await and not switch_occurred: + raise exc.AwaitRequired( + "The current operation required an async execution but none was " + "detected. This will usually happen when using a non compatible " + "DBAPI driver. Please ensure that an async DBAPI is used." + ) + return result + + +class AsyncAdaptedLock: + @memoized_property + def mutex(self): + # there should not be a race here for coroutines creating the + # new lock as we are not using await, so therefore no concurrency + return asyncio.Lock() + + def __enter__(self): + # await is used to acquire the lock only after the first calling + # coroutine has created the mutex. + await_fallback(self.mutex.acquire()) + return self + + def __exit__(self, *arg, **kw): + self.mutex.release() + + +def _util_async_run_coroutine_function(fn, *args, **kwargs): + """for test suite/ util only""" + + loop = get_event_loop() + if loop.is_running(): + raise Exception( + "for async run coroutine we expect that no greenlet or event " + "loop is running when we start out" + ) + return loop.run_until_complete(fn(*args, **kwargs)) + + +def _util_async_run(fn, *args, **kwargs): + """for test suite/ util only""" + + loop = get_event_loop() + if not loop.is_running(): + return loop.run_until_complete(greenlet_spawn(fn, *args, **kwargs)) + else: + # allow for a wrapped test function to call another + assert isinstance(greenlet.getcurrent(), _AsyncIoGreenlet) + return fn(*args, **kwargs) + + +def get_event_loop(): + """vendor asyncio.get_event_loop() for python 3.7 and above. + + Python 3.10 deprecates get_event_loop() as a standalone. + + """ + if compat.py37: + try: + return asyncio.get_running_loop() + except RuntimeError: + return asyncio.get_event_loop_policy().get_event_loop() + else: + return asyncio.get_event_loop() diff --git a/lib/sqlalchemy/util/_preloaded.py b/lib/sqlalchemy/util/_preloaded.py new file mode 100644 index 0000000..1803de4 --- /dev/null +++ b/lib/sqlalchemy/util/_preloaded.py @@ -0,0 +1,68 @@ +# util/_preloaded.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""supplies the "preloaded" registry to resolve circular module imports at +runtime. + +""" + +import sys + +from . import compat + + +class _ModuleRegistry: + """Registry of modules to load in a package init file. + + To avoid potential thread safety issues for imports that are deferred + in a function, like https://bugs.python.org/issue38884, these modules + are added to the system module cache by importing them after the packages + has finished initialization. + + A global instance is provided under the name :attr:`.preloaded`. Use + the function :func:`.preload_module` to register modules to load and + :meth:`.import_prefix` to load all the modules that start with the + given path. + + While the modules are loaded in the global module cache, it's advisable + to access them using :attr:`.preloaded` to ensure that it was actually + registered. Each registered module is added to the instance ``__dict__`` + in the form `<package>_<module>`, omitting ``sqlalchemy`` from the package + name. Example: ``sqlalchemy.sql.util`` becomes ``preloaded.sql_util``. + """ + + def __init__(self, prefix="sqlalchemy."): + self.module_registry = set() + self.prefix = prefix + + def preload_module(self, *deps): + """Adds the specified modules to the list to load. + + This method can be used both as a normal function and as a decorator. + No change is performed to the decorated object. + """ + self.module_registry.update(deps) + return lambda fn: fn + + def import_prefix(self, path): + """Resolve all the modules in the registry that start with the + specified path. + """ + for module in self.module_registry: + if self.prefix: + key = module.split(self.prefix)[-1].replace(".", "_") + else: + key = module + if ( + not path or module.startswith(path) + ) and key not in self.__dict__: + compat.import_(module, globals(), locals()) + self.__dict__[key] = sys.modules[module] + + +preloaded = _ModuleRegistry() +preload_module = preloaded.preload_module diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py new file mode 100644 index 0000000..21a9491 --- /dev/null +++ b/lib/sqlalchemy/util/compat.py @@ -0,0 +1,632 @@ +# util/compat.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Handle Python version/platform incompatibilities.""" + +import collections +import contextlib +import inspect +import operator +import platform +import sys + +py311 = sys.version_info >= (3, 11) +py39 = sys.version_info >= (3, 9) +py38 = sys.version_info >= (3, 8) +py37 = sys.version_info >= (3, 7) +py3k = sys.version_info >= (3, 0) +py2k = sys.version_info < (3, 0) +pypy = platform.python_implementation() == "PyPy" + + +cpython = platform.python_implementation() == "CPython" +win32 = sys.platform.startswith("win") +osx = sys.platform.startswith("darwin") +arm = "aarch" in platform.machine().lower() + +has_refcount_gc = bool(cpython) + +contextmanager = contextlib.contextmanager +dottedgetter = operator.attrgetter +namedtuple = collections.namedtuple +next = next # noqa + +FullArgSpec = collections.namedtuple( + "FullArgSpec", + [ + "args", + "varargs", + "varkw", + "defaults", + "kwonlyargs", + "kwonlydefaults", + "annotations", + ], +) + + +class nullcontext(object): + """Context manager that does no additional processing. + + Vendored from Python 3.7. + + """ + + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + def __exit__(self, *excinfo): + pass + + +try: + import threading +except ImportError: + import dummy_threading as threading # noqa + + +def inspect_getfullargspec(func): + """Fully vendored version of getfullargspec from Python 3.3.""" + + if inspect.ismethod(func): + func = func.__func__ + if not inspect.isfunction(func): + raise TypeError("{!r} is not a Python function".format(func)) + + co = func.__code__ + if not inspect.iscode(co): + raise TypeError("{!r} is not a code object".format(co)) + + nargs = co.co_argcount + names = co.co_varnames + nkwargs = co.co_kwonlyargcount if py3k else 0 + args = list(names[:nargs]) + kwonlyargs = list(names[nargs : nargs + nkwargs]) + + nargs += nkwargs + varargs = None + if co.co_flags & inspect.CO_VARARGS: + varargs = co.co_varnames[nargs] + nargs = nargs + 1 + varkw = None + if co.co_flags & inspect.CO_VARKEYWORDS: + varkw = co.co_varnames[nargs] + + return FullArgSpec( + args, + varargs, + varkw, + func.__defaults__, + kwonlyargs, + func.__kwdefaults__ if py3k else None, + func.__annotations__ if py3k else {}, + ) + + +if py38: + from importlib import metadata as importlib_metadata +else: + import importlib_metadata # noqa + + +def importlib_metadata_get(group): + ep = importlib_metadata.entry_points() + if hasattr(ep, "select"): + return ep.select(group=group) + else: + return ep.get(group, ()) + + +if py3k: + import base64 + import builtins + import configparser + import itertools + import pickle + + from functools import reduce + from io import BytesIO as byte_buffer + from io import StringIO + from itertools import zip_longest + from time import perf_counter + from urllib.parse import ( + quote_plus, + unquote_plus, + parse_qsl, + quote, + unquote, + ) + + string_types = (str,) + binary_types = (bytes,) + binary_type = bytes + text_type = str + int_types = (int,) + iterbytes = iter + long_type = int + + itertools_filterfalse = itertools.filterfalse + itertools_filter = filter + itertools_imap = map + + exec_ = getattr(builtins, "exec") + import_ = getattr(builtins, "__import__") + print_ = getattr(builtins, "print") + + def b(s): + return s.encode("latin-1") + + def b64decode(x): + return base64.b64decode(x.encode("ascii")) + + def b64encode(x): + return base64.b64encode(x).decode("ascii") + + def decode_backslashreplace(text, encoding): + return text.decode(encoding, errors="backslashreplace") + + def cmp(a, b): + return (a > b) - (a < b) + + def raise_( + exception, with_traceback=None, replace_context=None, from_=False + ): + r"""implement "raise" with cause support. + + :param exception: exception to raise + :param with_traceback: will call exception.with_traceback() + :param replace_context: an as-yet-unsupported feature. This is + an exception object which we are "replacing", e.g., it's our + "cause" but we don't want it printed. Basically just what + ``__suppress_context__`` does but we don't want to suppress + the enclosing context, if any. So for now we make it the + cause. + :param from\_: the cause. this actually sets the cause and doesn't + hope to hide it someday. + + """ + if with_traceback is not None: + exception = exception.with_traceback(with_traceback) + + if from_ is not False: + exception.__cause__ = from_ + elif replace_context is not None: + # no good solution here, we would like to have the exception + # have only the context of replace_context.__context__ so that the + # intermediary exception does not change, but we can't figure + # that out. + exception.__cause__ = replace_context + + try: + raise exception + finally: + # credit to + # https://cosmicpercolator.com/2016/01/13/exception-leaks-in-python-2-and-3/ + # as the __traceback__ object creates a cycle + del exception, replace_context, from_, with_traceback + + def u(s): + return s + + def ue(s): + return s + + from typing import TYPE_CHECKING + + # Unused. Kept for backwards compatibility. + callable = callable # noqa + + from abc import ABC + + def _qualname(fn): + return fn.__qualname__ + + +else: + import base64 + import ConfigParser as configparser # noqa + import itertools + + from StringIO import StringIO # noqa + from cStringIO import StringIO as byte_buffer # noqa + from itertools import izip_longest as zip_longest # noqa + from time import clock as perf_counter # noqa + from urllib import quote # noqa + from urllib import quote_plus # noqa + from urllib import unquote # noqa + from urllib import unquote_plus # noqa + from urlparse import parse_qsl # noqa + + from abc import ABCMeta + + class ABC(object): + __metaclass__ = ABCMeta + + try: + import cPickle as pickle + except ImportError: + import pickle # noqa + + string_types = (basestring,) # noqa + binary_types = (bytes,) + binary_type = str + text_type = unicode # noqa + int_types = int, long # noqa + long_type = long # noqa + + callable = callable # noqa + cmp = cmp # noqa + reduce = reduce # noqa + + b64encode = base64.b64encode + b64decode = base64.b64decode + + itertools_filterfalse = itertools.ifilterfalse + itertools_filter = itertools.ifilter + itertools_imap = itertools.imap + + def b(s): + return s + + def exec_(func_text, globals_, lcl=None): + if lcl is None: + exec("exec func_text in globals_") + else: + exec("exec func_text in globals_, lcl") + + def iterbytes(buf): + return (ord(byte) for byte in buf) + + def import_(*args): + if len(args) == 4: + args = args[0:3] + ([str(arg) for arg in args[3]],) + return __import__(*args) + + def print_(*args, **kwargs): + fp = kwargs.pop("file", sys.stdout) + if fp is None: + return + for arg in enumerate(args): + if not isinstance(arg, basestring): # noqa + arg = str(arg) + fp.write(arg) + + def u(s): + # this differs from what six does, which doesn't support non-ASCII + # strings - we only use u() with + # literal source strings, and all our source files with non-ascii + # in them (all are tests) are utf-8 encoded. + return unicode(s, "utf-8") # noqa + + def ue(s): + return unicode(s, "unicode_escape") # noqa + + def decode_backslashreplace(text, encoding): + try: + return text.decode(encoding) + except UnicodeDecodeError: + # regular "backslashreplace" for an incompatible encoding raises: + # "TypeError: don't know how to handle UnicodeDecodeError in + # error callback" + return repr(text)[1:-1].decode() + + def safe_bytestring(text): + # py2k only + if not isinstance(text, string_types): + return unicode(text).encode( # noqa: F821 + "ascii", errors="backslashreplace" + ) + elif isinstance(text, unicode): # noqa: F821 + return text.encode("ascii", errors="backslashreplace") + else: + return text + + exec( + "def raise_(exception, with_traceback=None, replace_context=None, " + "from_=False):\n" + " if with_traceback:\n" + " raise type(exception), exception, with_traceback\n" + " else:\n" + " raise exception\n" + ) + + TYPE_CHECKING = False + + def _qualname(meth): + """return __qualname__ equivalent for a method on a class""" + + for cls in meth.im_class.__mro__: + if meth.__name__ in cls.__dict__: + break + else: + return meth.__name__ + + return "%s.%s" % (cls.__name__, meth.__name__) + + +if py3k: + + def _formatannotation(annotation, base_module=None): + """vendored from python 3.7""" + + if getattr(annotation, "__module__", None) == "typing": + return repr(annotation).replace("typing.", "") + if isinstance(annotation, type): + if annotation.__module__ in ("builtins", base_module): + return annotation.__qualname__ + return annotation.__module__ + "." + annotation.__qualname__ + return repr(annotation) + + def inspect_formatargspec( + args, + varargs=None, + varkw=None, + defaults=None, + kwonlyargs=(), + kwonlydefaults={}, + annotations={}, + formatarg=str, + formatvarargs=lambda name: "*" + name, + formatvarkw=lambda name: "**" + name, + formatvalue=lambda value: "=" + repr(value), + formatreturns=lambda text: " -> " + text, + formatannotation=_formatannotation, + ): + """Copy formatargspec from python 3.7 standard library. + + Python 3 has deprecated formatargspec and requested that Signature + be used instead, however this requires a full reimplementation + of formatargspec() in terms of creating Parameter objects and such. + Instead of introducing all the object-creation overhead and having + to reinvent from scratch, just copy their compatibility routine. + + Ultimately we would need to rewrite our "decorator" routine completely + which is not really worth it right now, until all Python 2.x support + is dropped. + + """ + + kwonlydefaults = kwonlydefaults or {} + annotations = annotations or {} + + def formatargandannotation(arg): + result = formatarg(arg) + if arg in annotations: + result += ": " + formatannotation(annotations[arg]) + return result + + specs = [] + if defaults: + firstdefault = len(args) - len(defaults) + for i, arg in enumerate(args): + spec = formatargandannotation(arg) + if defaults and i >= firstdefault: + spec = spec + formatvalue(defaults[i - firstdefault]) + specs.append(spec) + + if varargs is not None: + specs.append(formatvarargs(formatargandannotation(varargs))) + else: + if kwonlyargs: + specs.append("*") + + if kwonlyargs: + for kwonlyarg in kwonlyargs: + spec = formatargandannotation(kwonlyarg) + if kwonlydefaults and kwonlyarg in kwonlydefaults: + spec += formatvalue(kwonlydefaults[kwonlyarg]) + specs.append(spec) + + if varkw is not None: + specs.append(formatvarkw(formatargandannotation(varkw))) + + result = "(" + ", ".join(specs) + ")" + if "return" in annotations: + result += formatreturns(formatannotation(annotations["return"])) + return result + + +else: + from inspect import formatargspec as _inspect_formatargspec + + def inspect_formatargspec(*spec, **kw): + # convert for a potential FullArgSpec from compat.getfullargspec() + return _inspect_formatargspec(*spec[0:4], **kw) # noqa + + +# Fix deprecation of accessing ABCs straight from collections module +# (which will stop working in 3.8). +if py3k: + import collections.abc as collections_abc +else: + import collections as collections_abc # noqa + + +if py37: + import dataclasses + + def dataclass_fields(cls): + """Return a sequence of all dataclasses.Field objects associated + with a class.""" + + if dataclasses.is_dataclass(cls): + return dataclasses.fields(cls) + else: + return [] + + def local_dataclass_fields(cls): + """Return a sequence of all dataclasses.Field objects associated with + a class, excluding those that originate from a superclass.""" + + if dataclasses.is_dataclass(cls): + super_fields = set() + for sup in cls.__bases__: + super_fields.update(dataclass_fields(sup)) + return [ + f for f in dataclasses.fields(cls) if f not in super_fields + ] + else: + return [] + + +else: + + def dataclass_fields(cls): + return [] + + def local_dataclass_fields(cls): + return [] + + +def raise_from_cause(exception, exc_info=None): + r"""legacy. use raise\_()""" + + if exc_info is None: + exc_info = sys.exc_info() + exc_type, exc_value, exc_tb = exc_info + cause = exc_value if exc_value is not exception else None + reraise(type(exception), exception, tb=exc_tb, cause=cause) + + +def reraise(tp, value, tb=None, cause=None): + r"""legacy. use raise\_()""" + + raise_(value, with_traceback=tb, from_=cause) + + +def with_metaclass(meta, *bases, **kw): + """Create a base class with a metaclass. + + Drops the middle class upon creation. + + Source: https://lucumr.pocoo.org/2013/5/21/porting-to-python-3-redux/ + + """ + + class metaclass(meta): + __call__ = type.__call__ + __init__ = type.__init__ + + def __new__(cls, name, this_bases, d): + if this_bases is None: + cls = type.__new__(cls, name, (), d) + else: + cls = meta(name, bases, d) + + if hasattr(cls, "__init_subclass__") and hasattr( + cls.__init_subclass__, "__func__" + ): + cls.__init_subclass__.__func__(cls, **kw) + return cls + + return metaclass("temporary_class", None, {}) + + +if py3k: + from datetime import timezone +else: + from datetime import datetime + from datetime import timedelta + from datetime import tzinfo + + class timezone(tzinfo): + """Minimal port of python 3 timezone object""" + + __slots__ = "_offset" + + def __init__(self, offset): + if not isinstance(offset, timedelta): + raise TypeError("offset must be a timedelta") + if not self._minoffset <= offset <= self._maxoffset: + raise ValueError( + "offset must be a timedelta " + "strictly between -timedelta(hours=24) and " + "timedelta(hours=24)." + ) + self._offset = offset + + def __eq__(self, other): + if type(other) != timezone: + return False + return self._offset == other._offset + + def __hash__(self): + return hash(self._offset) + + def __repr__(self): + return "sqlalchemy.util.%s(%r)" % ( + self.__class__.__name__, + self._offset, + ) + + def __str__(self): + return self.tzname(None) + + def utcoffset(self, dt): + return self._offset + + def tzname(self, dt): + return self._name_from_offset(self._offset) + + def dst(self, dt): + return None + + def fromutc(self, dt): + if isinstance(dt, datetime): + if dt.tzinfo is not self: + raise ValueError("fromutc: dt.tzinfo " "is not self") + return dt + self._offset + raise TypeError( + "fromutc() argument must be a datetime instance" " or None" + ) + + @staticmethod + def _timedelta_to_microseconds(timedelta): + """backport of timedelta._to_microseconds()""" + return ( + timedelta.days * (24 * 3600) + timedelta.seconds + ) * 1000000 + timedelta.microseconds + + @staticmethod + def _divmod_timedeltas(a, b): + """backport of timedelta.__divmod__""" + + q, r = divmod( + timezone._timedelta_to_microseconds(a), + timezone._timedelta_to_microseconds(b), + ) + return q, timedelta(0, 0, r) + + @staticmethod + def _name_from_offset(delta): + if not delta: + return "UTC" + if delta < timedelta(0): + sign = "-" + delta = -delta + else: + sign = "+" + hours, rest = timezone._divmod_timedeltas( + delta, timedelta(hours=1) + ) + minutes, rest = timezone._divmod_timedeltas( + rest, timedelta(minutes=1) + ) + result = "UTC%s%02d:%02d" % (sign, hours, minutes) + if rest.seconds: + result += ":%02d" % (rest.seconds,) + if rest.microseconds: + result += ".%06d" % (rest.microseconds,) + return result + + _maxoffset = timedelta(hours=23, minutes=59) + _minoffset = -_maxoffset + + timezone.utc = timezone(timedelta(0)) diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py new file mode 100644 index 0000000..e900b43 --- /dev/null +++ b/lib/sqlalchemy/util/concurrency.py @@ -0,0 +1,73 @@ +# util/concurrency.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from . import compat + +have_greenlet = False +greenlet_error = None + +if compat.py3k: + try: + import greenlet # noqa: F401 + except ImportError as e: + greenlet_error = str(e) + else: + have_greenlet = True + from ._concurrency_py3k import await_only + from ._concurrency_py3k import await_fallback + from ._concurrency_py3k import greenlet_spawn + from ._concurrency_py3k import is_exit_exception + from ._concurrency_py3k import AsyncAdaptedLock + from ._concurrency_py3k import _util_async_run # noqa: F401 + from ._concurrency_py3k import ( + _util_async_run_coroutine_function, + ) # noqa: F401, E501 + from ._concurrency_py3k import asyncio # noqa: F401 + + # does not need greennlet, just Python 3 + from ._compat_py3k import asynccontextmanager # noqa: F401 + +if not have_greenlet: + + asyncio = None # noqa: F811 + + def _not_implemented(): + # this conditional is to prevent pylance from considering + # greenlet_spawn() etc as "no return" and dimming out code below it + if have_greenlet: + return None + + if not compat.py3k: + raise ValueError("Cannot use this function in py2.") + else: + raise ValueError( + "the greenlet library is required to use this function." + " %s" % greenlet_error + if greenlet_error + else "" + ) + + def is_exit_exception(e): # noqa: F811 + return not isinstance(e, Exception) + + def await_only(thing): # noqa: F811 + _not_implemented() + + def await_fallback(thing): # noqa: F811 + return thing + + def greenlet_spawn(fn, *args, **kw): # noqa: F811 + _not_implemented() + + def AsyncAdaptedLock(*args, **kw): # noqa: F811 + _not_implemented() + + def _util_async_run(fn, *arg, **kw): # noqa: F811 + return fn(*arg, **kw) + + def _util_async_run_coroutine_function(fn, *arg, **kw): # noqa: F811 + _not_implemented() diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py new file mode 100644 index 0000000..b61516d --- /dev/null +++ b/lib/sqlalchemy/util/deprecations.py @@ -0,0 +1,417 @@ +# util/deprecations.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Helpers related to deprecation of functions, methods, classes, other +functionality.""" + +import os +import re + +from . import compat +from .langhelpers import _hash_limit_string +from .langhelpers import _warnings_warn +from .langhelpers import decorator +from .langhelpers import inject_docstring_text +from .langhelpers import inject_param_text +from .. import exc + + +SQLALCHEMY_WARN_20 = False + +if os.getenv("SQLALCHEMY_WARN_20", "false").lower() in ("true", "yes", "1"): + SQLALCHEMY_WARN_20 = True + + +def _warn_with_version(msg, version, type_, stacklevel, code=None): + if ( + issubclass(type_, exc.Base20DeprecationWarning) + and not SQLALCHEMY_WARN_20 + ): + return + + warn = type_(msg, code=code) + warn.deprecated_since = version + + _warnings_warn(warn, stacklevel=stacklevel + 1) + + +def warn_deprecated(msg, version, stacklevel=3, code=None): + _warn_with_version( + msg, version, exc.SADeprecationWarning, stacklevel, code=code + ) + + +def warn_deprecated_limited(msg, args, version, stacklevel=3, code=None): + """Issue a deprecation warning with a parameterized string, + limiting the number of registrations. + + """ + if args: + msg = _hash_limit_string(msg, 10, args) + _warn_with_version( + msg, version, exc.SADeprecationWarning, stacklevel, code=code + ) + + +def warn_deprecated_20(msg, stacklevel=3, code=None): + + _warn_with_version( + msg, + exc.RemovedIn20Warning.deprecated_since, + exc.RemovedIn20Warning, + stacklevel, + code=code, + ) + + +def deprecated_cls(version, message, constructor="__init__"): + header = ".. deprecated:: %s %s" % (version, (message or "")) + + def decorate(cls): + return _decorate_cls_with_warning( + cls, + constructor, + exc.SADeprecationWarning, + message % dict(func=constructor), + version, + header, + ) + + return decorate + + +def deprecated_20_cls( + clsname, alternative=None, constructor="__init__", becomes_legacy=False +): + message = ( + ".. deprecated:: 1.4 The %s class is considered legacy as of the " + "1.x series of SQLAlchemy and %s in 2.0." + % ( + clsname, + "will be removed" + if not becomes_legacy + else "becomes a legacy construct", + ) + ) + + if alternative: + message += " " + alternative + + if becomes_legacy: + warning_cls = exc.LegacyAPIWarning + else: + warning_cls = exc.RemovedIn20Warning + + def decorate(cls): + return _decorate_cls_with_warning( + cls, + constructor, + warning_cls, + message, + warning_cls.deprecated_since, + message, + ) + + return decorate + + +def deprecated( + version, + message=None, + add_deprecation_to_docstring=True, + warning=None, + enable_warnings=True, +): + """Decorates a function and issues a deprecation warning on use. + + :param version: + Issue version in the warning. + + :param message: + If provided, issue message in the warning. A sensible default + is used if not provided. + + :param add_deprecation_to_docstring: + Default True. If False, the wrapped function's __doc__ is left + as-is. If True, the 'message' is prepended to the docs if + provided, or sensible default if message is omitted. + + """ + + # nothing is deprecated "since" 2.0 at this time. All "removed in 2.0" + # should emit the RemovedIn20Warning, but messaging should be expressed + # in terms of "deprecated since 1.4". + + if version == "2.0": + if warning is None: + warning = exc.RemovedIn20Warning + version = "1.4" + if add_deprecation_to_docstring: + header = ".. deprecated:: %s %s" % ( + version, + (message or ""), + ) + else: + header = None + + if message is None: + message = "Call to deprecated function %(func)s" + + if warning is None: + warning = exc.SADeprecationWarning + + if warning is not exc.RemovedIn20Warning: + message += " (deprecated since: %s)" % version + + def decorate(fn): + return _decorate_with_warning( + fn, + warning, + message % dict(func=fn.__name__), + version, + header, + enable_warnings=enable_warnings, + ) + + return decorate + + +def moved_20(message, **kw): + return deprecated( + "2.0", message=message, warning=exc.MovedIn20Warning, **kw + ) + + +def deprecated_20(api_name, alternative=None, becomes_legacy=False, **kw): + type_reg = re.match("^:(attr|func|meth):", api_name) + if type_reg: + type_ = {"attr": "attribute", "func": "function", "meth": "method"}[ + type_reg.group(1) + ] + else: + type_ = "construct" + message = ( + "The %s %s is considered legacy as of the " + "1.x series of SQLAlchemy and %s in 2.0." + % ( + api_name, + type_, + "will be removed" + if not becomes_legacy + else "becomes a legacy construct", + ) + ) + + if ":attr:" in api_name: + attribute_ok = kw.pop("warn_on_attribute_access", False) + if not attribute_ok: + assert kw.get("enable_warnings") is False, ( + "attribute %s will emit a warning on read access. " + "If you *really* want this, " + "add warn_on_attribute_access=True. Otherwise please add " + "enable_warnings=False." % api_name + ) + + if alternative: + message += " " + alternative + + if becomes_legacy: + warning_cls = exc.LegacyAPIWarning + else: + warning_cls = exc.RemovedIn20Warning + + return deprecated("2.0", message=message, warning=warning_cls, **kw) + + +def deprecated_params(**specs): + """Decorates a function to warn on use of certain parameters. + + e.g. :: + + @deprecated_params( + weak_identity_map=( + "0.7", + "the :paramref:`.Session.weak_identity_map parameter " + "is deprecated." + ) + + ) + + """ + + messages = {} + versions = {} + version_warnings = {} + + for param, (version, message) in specs.items(): + versions[param] = version + messages[param] = _sanitize_restructured_text(message) + version_warnings[param] = ( + exc.RemovedIn20Warning + if version == "2.0" + else exc.SADeprecationWarning + ) + + def decorate(fn): + spec = compat.inspect_getfullargspec(fn) + + if spec.defaults is not None: + defaults = dict( + zip( + spec.args[(len(spec.args) - len(spec.defaults)) :], + spec.defaults, + ) + ) + check_defaults = set(defaults).intersection(messages) + check_kw = set(messages).difference(defaults) + else: + check_defaults = () + check_kw = set(messages) + + check_any_kw = spec.varkw + + @decorator + def warned(fn, *args, **kwargs): + for m in check_defaults: + if (defaults[m] is None and kwargs[m] is not None) or ( + defaults[m] is not None and kwargs[m] != defaults[m] + ): + _warn_with_version( + messages[m], + versions[m], + version_warnings[m], + stacklevel=3, + ) + + if check_any_kw in messages and set(kwargs).difference( + check_defaults + ): + + _warn_with_version( + messages[check_any_kw], + versions[check_any_kw], + version_warnings[check_any_kw], + stacklevel=3, + ) + + for m in check_kw: + if m in kwargs: + _warn_with_version( + messages[m], + versions[m], + version_warnings[m], + stacklevel=3, + ) + return fn(*args, **kwargs) + + doc = fn.__doc__ is not None and fn.__doc__ or "" + if doc: + doc = inject_param_text( + doc, + { + param: ".. deprecated:: %s %s" + % ("1.4" if version == "2.0" else version, (message or "")) + for param, (version, message) in specs.items() + }, + ) + decorated = warned(fn) + decorated.__doc__ = doc + return decorated + + return decorate + + +def _sanitize_restructured_text(text): + def repl(m): + type_, name = m.group(1, 2) + if type_ in ("func", "meth"): + name += "()" + return name + + text = re.sub(r":ref:`(.+) <.*>`", lambda m: '"%s"' % m.group(1), text) + return re.sub(r"\:(\w+)\:`~?(?:_\w+)?\.?(.+?)`", repl, text) + + +def _decorate_cls_with_warning( + cls, constructor, wtype, message, version, docstring_header=None +): + doc = cls.__doc__ is not None and cls.__doc__ or "" + if docstring_header is not None: + + if constructor is not None: + docstring_header %= dict(func=constructor) + + if issubclass(wtype, exc.Base20DeprecationWarning): + docstring_header += ( + " (Background on SQLAlchemy 2.0 at: " + ":ref:`migration_20_toplevel`)" + ) + doc = inject_docstring_text(doc, docstring_header, 1) + + if type(cls) is type: + clsdict = dict(cls.__dict__) + clsdict["__doc__"] = doc + clsdict.pop("__dict__", None) + clsdict.pop("__weakref__", None) + cls = type(cls.__name__, cls.__bases__, clsdict) + if constructor is not None: + constructor_fn = clsdict[constructor] + + else: + cls.__doc__ = doc + if constructor is not None: + constructor_fn = getattr(cls, constructor) + + if constructor is not None: + setattr( + cls, + constructor, + _decorate_with_warning( + constructor_fn, wtype, message, version, None + ), + ) + return cls + + +def _decorate_with_warning( + func, wtype, message, version, docstring_header=None, enable_warnings=True +): + """Wrap a function with a warnings.warn and augmented docstring.""" + + message = _sanitize_restructured_text(message) + + if issubclass(wtype, exc.Base20DeprecationWarning): + doc_only = ( + " (Background on SQLAlchemy 2.0 at: " + ":ref:`migration_20_toplevel`)" + ) + else: + doc_only = "" + + @decorator + def warned(fn, *args, **kwargs): + skip_warning = not enable_warnings or kwargs.pop( + "_sa_skip_warning", False + ) + if not skip_warning: + _warn_with_version(message, version, wtype, stacklevel=3) + return fn(*args, **kwargs) + + doc = func.__doc__ is not None and func.__doc__ or "" + if docstring_header is not None: + docstring_header %= dict(func=func.__name__) + + docstring_header += doc_only + + doc = inject_docstring_text(doc, docstring_header, 1) + + decorated = warned(func) + decorated.__doc__ = doc + decorated._sa_warn = lambda: _warn_with_version( + message, version, wtype, stacklevel=3 + ) + return decorated diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py new file mode 100644 index 0000000..c3636f0 --- /dev/null +++ b/lib/sqlalchemy/util/langhelpers.py @@ -0,0 +1,1945 @@ +# util/langhelpers.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Routines to help with the creation, loading and introspection of +modules, classes, hierarchies, attributes, functions, and methods. + +""" + +import collections +from functools import update_wrapper +import hashlib +import inspect +import itertools +import operator +import re +import sys +import textwrap +import types +import warnings + +from . import _collections +from . import compat +from .. import exc + + +def md5_hex(x): + if compat.py3k: + x = x.encode("utf-8") + m = hashlib.md5() + m.update(x) + return m.hexdigest() + + +class safe_reraise(object): + """Reraise an exception after invoking some + handler code. + + Stores the existing exception info before + invoking so that it is maintained across a potential + coroutine context switch. + + e.g.:: + + try: + sess.commit() + except: + with safe_reraise(): + sess.rollback() + + """ + + __slots__ = ("warn_only", "_exc_info") + + def __init__(self, warn_only=False): + self.warn_only = warn_only + + def __enter__(self): + self._exc_info = sys.exc_info() + + def __exit__(self, type_, value, traceback): + # see #2703 for notes + if type_ is None: + exc_type, exc_value, exc_tb = self._exc_info + self._exc_info = None # remove potential circular references + if not self.warn_only: + compat.raise_( + exc_value, + with_traceback=exc_tb, + ) + else: + if not compat.py3k and self._exc_info and self._exc_info[1]: + # emulate Py3K's behavior of telling us when an exception + # occurs in an exception handler. + warn( + "An exception has occurred during handling of a " + "previous exception. The previous exception " + "is:\n %s %s\n" % (self._exc_info[0], self._exc_info[1]) + ) + self._exc_info = None # remove potential circular references + compat.raise_(value, with_traceback=traceback) + + +def walk_subclasses(cls): + seen = set() + + stack = [cls] + while stack: + cls = stack.pop() + if cls in seen: + continue + else: + seen.add(cls) + stack.extend(cls.__subclasses__()) + yield cls + + +def string_or_unprintable(element): + if isinstance(element, compat.string_types): + return element + else: + try: + return str(element) + except Exception: + return "unprintable element %r" % element + + +def clsname_as_plain_name(cls): + return " ".join( + n.lower() for n in re.findall(r"([A-Z][a-z]+)", cls.__name__) + ) + + +def method_is_overridden(instance_or_cls, against_method): + """Return True if the two class methods don't match.""" + + if not isinstance(instance_or_cls, type): + current_cls = instance_or_cls.__class__ + else: + current_cls = instance_or_cls + + method_name = against_method.__name__ + + current_method = getattr(current_cls, method_name) + + return current_method != against_method + + +def decode_slice(slc): + """decode a slice object as sent to __getitem__. + + takes into account the 2.5 __index__() method, basically. + + """ + ret = [] + for x in slc.start, slc.stop, slc.step: + if hasattr(x, "__index__"): + x = x.__index__() + ret.append(x) + return tuple(ret) + + +def _unique_symbols(used, *bases): + used = set(used) + for base in bases: + pool = itertools.chain( + (base,), + compat.itertools_imap(lambda i: base + str(i), range(1000)), + ) + for sym in pool: + if sym not in used: + used.add(sym) + yield sym + break + else: + raise NameError("exhausted namespace for symbol base %s" % base) + + +def map_bits(fn, n): + """Call the given function given each nonzero bit from n.""" + + while n: + b = n & (~n + 1) + yield fn(b) + n ^= b + + +def decorator(target): + """A signature-matching decorator factory.""" + + def decorate(fn): + if not inspect.isfunction(fn) and not inspect.ismethod(fn): + raise Exception("not a decoratable function") + + spec = compat.inspect_getfullargspec(fn) + env = {} + + spec = _update_argspec_defaults_into_env(spec, env) + + names = tuple(spec[0]) + spec[1:3] + (fn.__name__,) + targ_name, fn_name = _unique_symbols(names, "target", "fn") + + metadata = dict(target=targ_name, fn=fn_name) + metadata.update(format_argspec_plus(spec, grouped=False)) + metadata["name"] = fn.__name__ + code = ( + """\ +def %(name)s(%(args)s): + return %(target)s(%(fn)s, %(apply_kw)s) +""" + % metadata + ) + env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__}) + + decorated = _exec_code_in_env(code, env, fn.__name__) + decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ + decorated.__wrapped__ = fn + return update_wrapper(decorated, fn) + + return update_wrapper(decorate, target) + + +def _update_argspec_defaults_into_env(spec, env): + """given a FullArgSpec, convert defaults to be symbol names in an env.""" + + if spec.defaults: + new_defaults = [] + i = 0 + for arg in spec.defaults: + if type(arg).__module__ not in ("builtins", "__builtin__"): + name = "x%d" % i + env[name] = arg + new_defaults.append(name) + i += 1 + else: + new_defaults.append(arg) + elem = list(spec) + elem[3] = tuple(new_defaults) + return compat.FullArgSpec(*elem) + else: + return spec + + +def _exec_code_in_env(code, env, fn_name): + exec(code, env) + return env[fn_name] + + +def public_factory(target, location, class_location=None): + """Produce a wrapping function for the given cls or classmethod. + + Rationale here is so that the __init__ method of the + class can serve as documentation for the function. + + """ + + if isinstance(target, type): + fn = target.__init__ + callable_ = target + doc = ( + "Construct a new :class:`%s` object. \n\n" + "This constructor is mirrored as a public API function; " + "see :func:`sqlalchemy%s` " + "for a full usage and argument description." + % ( + class_location if class_location else ".%s" % target.__name__, + location, + ) + ) + else: + fn = callable_ = target + doc = ( + "This function is mirrored; see :func:`sqlalchemy%s` " + "for a description of arguments." % location + ) + + location_name = location.split(".")[-1] + spec = compat.inspect_getfullargspec(fn) + del spec[0][0] + metadata = format_argspec_plus(spec, grouped=False) + metadata["name"] = location_name + code = ( + """\ +def %(name)s(%(args)s): + return cls(%(apply_kw)s) +""" + % metadata + ) + env = { + "cls": callable_, + "symbol": symbol, + "__name__": callable_.__module__, + } + exec(code, env) + decorated = env[location_name] + + if hasattr(fn, "_linked_to"): + linked_to, linked_to_location = fn._linked_to + linked_to_doc = linked_to.__doc__ + if class_location is None: + class_location = "%s.%s" % (target.__module__, target.__name__) + + linked_to_doc = inject_docstring_text( + linked_to_doc, + ".. container:: inherited_member\n\n " + "This documentation is inherited from :func:`sqlalchemy%s`; " + "this constructor, :func:`sqlalchemy%s`, " + "creates a :class:`sqlalchemy%s` object. See that class for " + "additional details describing this subclass." + % (linked_to_location, location, class_location), + 1, + ) + decorated.__doc__ = linked_to_doc + else: + decorated.__doc__ = fn.__doc__ + + decorated.__module__ = "sqlalchemy" + location.rsplit(".", 1)[0] + if decorated.__module__ not in sys.modules: + raise ImportError( + "public_factory location %s is not in sys.modules" + % (decorated.__module__,) + ) + + if compat.py2k or hasattr(fn, "__func__"): + fn.__func__.__doc__ = doc + if not hasattr(fn.__func__, "_linked_to"): + fn.__func__._linked_to = (decorated, location) + else: + fn.__doc__ = doc + if not hasattr(fn, "_linked_to"): + fn._linked_to = (decorated, location) + + return decorated + + +class PluginLoader(object): + def __init__(self, group, auto_fn=None): + self.group = group + self.impls = {} + self.auto_fn = auto_fn + + def clear(self): + self.impls.clear() + + def load(self, name): + if name in self.impls: + return self.impls[name]() + + if self.auto_fn: + loader = self.auto_fn(name) + if loader: + self.impls[name] = loader + return loader() + + for impl in compat.importlib_metadata_get(self.group): + if impl.name == name: + self.impls[name] = impl.load + return impl.load() + + raise exc.NoSuchModuleError( + "Can't load plugin: %s:%s" % (self.group, name) + ) + + def register(self, name, modulepath, objname): + def load(): + mod = compat.import_(modulepath) + for token in modulepath.split(".")[1:]: + mod = getattr(mod, token) + return getattr(mod, objname) + + self.impls[name] = load + + +def _inspect_func_args(fn): + try: + co_varkeywords = inspect.CO_VARKEYWORDS + except AttributeError: + # https://docs.python.org/3/library/inspect.html + # The flags are specific to CPython, and may not be defined in other + # Python implementations. Furthermore, the flags are an implementation + # detail, and can be removed or deprecated in future Python releases. + spec = compat.inspect_getfullargspec(fn) + return spec[0], bool(spec[2]) + else: + # use fn.__code__ plus flags to reduce method call overhead + co = fn.__code__ + nargs = co.co_argcount + return ( + list(co.co_varnames[:nargs]), + bool(co.co_flags & co_varkeywords), + ) + + +def get_cls_kwargs(cls, _set=None): + r"""Return the full set of inherited kwargs for the given `cls`. + + Probes a class's __init__ method, collecting all named arguments. If the + __init__ defines a \**kwargs catch-all, then the constructor is presumed + to pass along unrecognized keywords to its base classes, and the + collection process is repeated recursively on each of the bases. + + Uses a subset of inspect.getfullargspec() to cut down on method overhead, + as this is used within the Core typing system to create copies of type + objects which is a performance-sensitive operation. + + No anonymous tuple arguments please ! + + """ + toplevel = _set is None + if toplevel: + _set = set() + + ctr = cls.__dict__.get("__init__", False) + + has_init = ( + ctr + and isinstance(ctr, types.FunctionType) + and isinstance(ctr.__code__, types.CodeType) + ) + + if has_init: + names, has_kw = _inspect_func_args(ctr) + _set.update(names) + + if not has_kw and not toplevel: + return None + + if not has_init or has_kw: + for c in cls.__bases__: + if get_cls_kwargs(c, _set) is None: + break + + _set.discard("self") + return _set + + +def get_func_kwargs(func): + """Return the set of legal kwargs for the given `func`. + + Uses getargspec so is safe to call for methods, functions, + etc. + + """ + + return compat.inspect_getfullargspec(func)[0] + + +def get_callable_argspec(fn, no_self=False, _is_init=False): + """Return the argument signature for any callable. + + All pure-Python callables are accepted, including + functions, methods, classes, objects with __call__; + builtins and other edge cases like functools.partial() objects + raise a TypeError. + + """ + if inspect.isbuiltin(fn): + raise TypeError("Can't inspect builtin: %s" % fn) + elif inspect.isfunction(fn): + if _is_init and no_self: + spec = compat.inspect_getfullargspec(fn) + return compat.FullArgSpec( + spec.args[1:], + spec.varargs, + spec.varkw, + spec.defaults, + spec.kwonlyargs, + spec.kwonlydefaults, + spec.annotations, + ) + else: + return compat.inspect_getfullargspec(fn) + elif inspect.ismethod(fn): + if no_self and (_is_init or fn.__self__): + spec = compat.inspect_getfullargspec(fn.__func__) + return compat.FullArgSpec( + spec.args[1:], + spec.varargs, + spec.varkw, + spec.defaults, + spec.kwonlyargs, + spec.kwonlydefaults, + spec.annotations, + ) + else: + return compat.inspect_getfullargspec(fn.__func__) + elif inspect.isclass(fn): + return get_callable_argspec( + fn.__init__, no_self=no_self, _is_init=True + ) + elif hasattr(fn, "__func__"): + return compat.inspect_getfullargspec(fn.__func__) + elif hasattr(fn, "__call__"): + if inspect.ismethod(fn.__call__): + return get_callable_argspec(fn.__call__, no_self=no_self) + else: + raise TypeError("Can't inspect callable: %s" % fn) + else: + raise TypeError("Can't inspect callable: %s" % fn) + + +def format_argspec_plus(fn, grouped=True): + """Returns a dictionary of formatted, introspected function arguments. + + A enhanced variant of inspect.formatargspec to support code generation. + + fn + An inspectable callable or tuple of inspect getargspec() results. + grouped + Defaults to True; include (parens, around, argument) lists + + Returns: + + args + Full inspect.formatargspec for fn + self_arg + The name of the first positional argument, varargs[0], or None + if the function defines no positional arguments. + apply_pos + args, re-written in calling rather than receiving syntax. Arguments are + passed positionally. + apply_kw + Like apply_pos, except keyword-ish args are passed as keywords. + apply_pos_proxied + Like apply_pos but omits the self/cls argument + + Example:: + + >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123) + {'args': '(self, a, b, c=3, **d)', + 'self_arg': 'self', + 'apply_kw': '(self, a, b, c=c, **d)', + 'apply_pos': '(self, a, b, c, **d)'} + + """ + if compat.callable(fn): + spec = compat.inspect_getfullargspec(fn) + else: + spec = fn + + args = compat.inspect_formatargspec(*spec) + + apply_pos = compat.inspect_formatargspec( + spec[0], spec[1], spec[2], None, spec[4] + ) + + if spec[0]: + self_arg = spec[0][0] + + apply_pos_proxied = compat.inspect_formatargspec( + spec[0][1:], spec[1], spec[2], None, spec[4] + ) + + elif spec[1]: + # I'm not sure what this is + self_arg = "%s[0]" % spec[1] + + apply_pos_proxied = apply_pos + else: + self_arg = None + apply_pos_proxied = apply_pos + + num_defaults = 0 + if spec[3]: + num_defaults += len(spec[3]) + if spec[4]: + num_defaults += len(spec[4]) + name_args = spec[0] + spec[4] + + if num_defaults: + defaulted_vals = name_args[0 - num_defaults :] + else: + defaulted_vals = () + + apply_kw = compat.inspect_formatargspec( + name_args, + spec[1], + spec[2], + defaulted_vals, + formatvalue=lambda x: "=" + x, + ) + + if spec[0]: + apply_kw_proxied = compat.inspect_formatargspec( + name_args[1:], + spec[1], + spec[2], + defaulted_vals, + formatvalue=lambda x: "=" + x, + ) + else: + apply_kw_proxied = apply_kw + + if grouped: + return dict( + args=args, + self_arg=self_arg, + apply_pos=apply_pos, + apply_kw=apply_kw, + apply_pos_proxied=apply_pos_proxied, + apply_kw_proxied=apply_kw_proxied, + ) + else: + return dict( + args=args[1:-1], + self_arg=self_arg, + apply_pos=apply_pos[1:-1], + apply_kw=apply_kw[1:-1], + apply_pos_proxied=apply_pos_proxied[1:-1], + apply_kw_proxied=apply_kw_proxied[1:-1], + ) + + +def format_argspec_init(method, grouped=True): + """format_argspec_plus with considerations for typical __init__ methods + + Wraps format_argspec_plus with error handling strategies for typical + __init__ cases:: + + object.__init__ -> (self) + other unreflectable (usually C) -> (self, *args, **kwargs) + + """ + if method is object.__init__: + args = "(self)" if grouped else "self" + proxied = "()" if grouped else "" + else: + try: + return format_argspec_plus(method, grouped=grouped) + except TypeError: + args = ( + "(self, *args, **kwargs)" + if grouped + else "self, *args, **kwargs" + ) + proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs" + return dict( + self_arg="self", + args=args, + apply_pos=args, + apply_kw=args, + apply_pos_proxied=proxied, + apply_kw_proxied=proxied, + ) + + +def create_proxy_methods( + target_cls, + target_cls_sphinx_name, + proxy_cls_sphinx_name, + classmethods=(), + methods=(), + attributes=(), +): + """A class decorator that will copy attributes to a proxy class. + + The class to be instrumented must define a single accessor "_proxied". + + """ + + def decorate(cls): + def instrument(name, clslevel=False): + fn = getattr(target_cls, name) + spec = compat.inspect_getfullargspec(fn) + env = {"__name__": fn.__module__} + + spec = _update_argspec_defaults_into_env(spec, env) + caller_argspec = format_argspec_plus(spec, grouped=False) + + metadata = { + "name": fn.__name__, + "apply_pos_proxied": caller_argspec["apply_pos_proxied"], + "apply_kw_proxied": caller_argspec["apply_kw_proxied"], + "args": caller_argspec["args"], + "self_arg": caller_argspec["self_arg"], + } + + if clslevel: + code = ( + "def %(name)s(%(args)s):\n" + " return target_cls.%(name)s(%(apply_kw_proxied)s)" + % metadata + ) + env["target_cls"] = target_cls + else: + code = ( + "def %(name)s(%(args)s):\n" + " return %(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s)" # noqa: E501 + % metadata + ) + + proxy_fn = _exec_code_in_env(code, env, fn.__name__) + proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__ + proxy_fn.__doc__ = inject_docstring_text( + fn.__doc__, + ".. container:: class_bases\n\n " + "Proxied for the %s class on behalf of the %s class." + % (target_cls_sphinx_name, proxy_cls_sphinx_name), + 1, + ) + + if clslevel: + proxy_fn = classmethod(proxy_fn) + + return proxy_fn + + def makeprop(name): + attr = target_cls.__dict__.get(name, None) + + if attr is not None: + doc = inject_docstring_text( + attr.__doc__, + ".. container:: class_bases\n\n " + "Proxied for the %s class on behalf of the %s class." + % ( + target_cls_sphinx_name, + proxy_cls_sphinx_name, + ), + 1, + ) + else: + doc = None + + code = ( + "def set_(self, attr):\n" + " self._proxied.%(name)s = attr\n" + "def get(self):\n" + " return self._proxied.%(name)s\n" + "get.__doc__ = doc\n" + "getset = property(get, set_)" + ) % {"name": name} + + getset = _exec_code_in_env(code, {"doc": doc}, "getset") + + return getset + + for meth in methods: + if hasattr(cls, meth): + raise TypeError( + "class %s already has a method %s" % (cls, meth) + ) + setattr(cls, meth, instrument(meth)) + + for prop in attributes: + if hasattr(cls, prop): + raise TypeError( + "class %s already has a method %s" % (cls, prop) + ) + setattr(cls, prop, makeprop(prop)) + + for prop in classmethods: + if hasattr(cls, prop): + raise TypeError( + "class %s already has a method %s" % (cls, prop) + ) + setattr(cls, prop, instrument(prop, clslevel=True)) + + return cls + + return decorate + + +def getargspec_init(method): + """inspect.getargspec with considerations for typical __init__ methods + + Wraps inspect.getargspec with error handling for typical __init__ cases:: + + object.__init__ -> (self) + other unreflectable (usually C) -> (self, *args, **kwargs) + + """ + try: + return compat.inspect_getfullargspec(method) + except TypeError: + if method is object.__init__: + return (["self"], None, None, None) + else: + return (["self"], "args", "kwargs", None) + + +def unbound_method_to_callable(func_or_cls): + """Adjust the incoming callable such that a 'self' argument is not + required. + + """ + + if isinstance(func_or_cls, types.MethodType) and not func_or_cls.__self__: + return func_or_cls.__func__ + else: + return func_or_cls + + +def generic_repr(obj, additional_kw=(), to_inspect=None, omit_kwarg=()): + """Produce a __repr__() based on direct association of the __init__() + specification vs. same-named attributes present. + + """ + if to_inspect is None: + to_inspect = [obj] + else: + to_inspect = _collections.to_list(to_inspect) + + missing = object() + + pos_args = [] + kw_args = _collections.OrderedDict() + vargs = None + for i, insp in enumerate(to_inspect): + try: + spec = compat.inspect_getfullargspec(insp.__init__) + except TypeError: + continue + else: + default_len = spec.defaults and len(spec.defaults) or 0 + if i == 0: + if spec.varargs: + vargs = spec.varargs + if default_len: + pos_args.extend(spec.args[1:-default_len]) + else: + pos_args.extend(spec.args[1:]) + else: + kw_args.update( + [(arg, missing) for arg in spec.args[1:-default_len]] + ) + + if default_len: + kw_args.update( + [ + (arg, default) + for arg, default in zip( + spec.args[-default_len:], spec.defaults + ) + ] + ) + output = [] + + output.extend(repr(getattr(obj, arg, None)) for arg in pos_args) + + if vargs is not None and hasattr(obj, vargs): + output.extend([repr(val) for val in getattr(obj, vargs)]) + + for arg, defval in kw_args.items(): + if arg in omit_kwarg: + continue + try: + val = getattr(obj, arg, missing) + if val is not missing and val != defval: + output.append("%s=%r" % (arg, val)) + except Exception: + pass + + if additional_kw: + for arg, defval in additional_kw: + try: + val = getattr(obj, arg, missing) + if val is not missing and val != defval: + output.append("%s=%r" % (arg, val)) + except Exception: + pass + + return "%s(%s)" % (obj.__class__.__name__, ", ".join(output)) + + +class portable_instancemethod(object): + """Turn an instancemethod into a (parent, name) pair + to produce a serializable callable. + + """ + + __slots__ = "target", "name", "kwargs", "__weakref__" + + def __getstate__(self): + return { + "target": self.target, + "name": self.name, + "kwargs": self.kwargs, + } + + def __setstate__(self, state): + self.target = state["target"] + self.name = state["name"] + self.kwargs = state.get("kwargs", ()) + + def __init__(self, meth, kwargs=()): + self.target = meth.__self__ + self.name = meth.__name__ + self.kwargs = kwargs + + def __call__(self, *arg, **kw): + kw.update(self.kwargs) + return getattr(self.target, self.name)(*arg, **kw) + + +def class_hierarchy(cls): + """Return an unordered sequence of all classes related to cls. + + Traverses diamond hierarchies. + + Fibs slightly: subclasses of builtin types are not returned. Thus + class_hierarchy(class A(object)) returns (A, object), not A plus every + class systemwide that derives from object. + + Old-style classes are discarded and hierarchies rooted on them + will not be descended. + + """ + if compat.py2k: + if isinstance(cls, types.ClassType): + return list() + + hier = {cls} + process = list(cls.__mro__) + while process: + c = process.pop() + if compat.py2k: + if isinstance(c, types.ClassType): + continue + bases = ( + _ + for _ in c.__bases__ + if _ not in hier and not isinstance(_, types.ClassType) + ) + else: + bases = (_ for _ in c.__bases__ if _ not in hier) + + for b in bases: + process.append(b) + hier.add(b) + + if compat.py3k: + if c.__module__ == "builtins" or not hasattr(c, "__subclasses__"): + continue + else: + if c.__module__ == "__builtin__" or not hasattr( + c, "__subclasses__" + ): + continue + + for s in [_ for _ in c.__subclasses__() if _ not in hier]: + process.append(s) + hier.add(s) + return list(hier) + + +def iterate_attributes(cls): + """iterate all the keys and attributes associated + with a class, without using getattr(). + + Does not use getattr() so that class-sensitive + descriptors (i.e. property.__get__()) are not called. + + """ + keys = dir(cls) + for key in keys: + for c in cls.__mro__: + if key in c.__dict__: + yield (key, c.__dict__[key]) + break + + +def monkeypatch_proxied_specials( + into_cls, + from_cls, + skip=None, + only=None, + name="self.proxy", + from_instance=None, +): + """Automates delegation of __specials__ for a proxying type.""" + + if only: + dunders = only + else: + if skip is None: + skip = ( + "__slots__", + "__del__", + "__getattribute__", + "__metaclass__", + "__getstate__", + "__setstate__", + ) + dunders = [ + m + for m in dir(from_cls) + if ( + m.startswith("__") + and m.endswith("__") + and not hasattr(into_cls, m) + and m not in skip + ) + ] + + for method in dunders: + try: + fn = getattr(from_cls, method) + if not hasattr(fn, "__call__"): + continue + fn = getattr(fn, "__func__", fn) + except AttributeError: + continue + try: + spec = compat.inspect_getfullargspec(fn) + fn_args = compat.inspect_formatargspec(spec[0]) + d_args = compat.inspect_formatargspec(spec[0][1:]) + except TypeError: + fn_args = "(self, *args, **kw)" + d_args = "(*args, **kw)" + + py = ( + "def %(method)s%(fn_args)s: " + "return %(name)s.%(method)s%(d_args)s" % locals() + ) + + env = from_instance is not None and {name: from_instance} or {} + compat.exec_(py, env) + try: + env[method].__defaults__ = fn.__defaults__ + except AttributeError: + pass + setattr(into_cls, method, env[method]) + + +def methods_equivalent(meth1, meth2): + """Return True if the two methods are the same implementation.""" + + return getattr(meth1, "__func__", meth1) is getattr( + meth2, "__func__", meth2 + ) + + +def as_interface(obj, cls=None, methods=None, required=None): + """Ensure basic interface compliance for an instance or dict of callables. + + Checks that ``obj`` implements public methods of ``cls`` or has members + listed in ``methods``. If ``required`` is not supplied, implementing at + least one interface method is sufficient. Methods present on ``obj`` that + are not in the interface are ignored. + + If ``obj`` is a dict and ``dict`` does not meet the interface + requirements, the keys of the dictionary are inspected. Keys present in + ``obj`` that are not in the interface will raise TypeErrors. + + Raises TypeError if ``obj`` does not meet the interface criteria. + + In all passing cases, an object with callable members is returned. In the + simple case, ``obj`` is returned as-is; if dict processing kicks in then + an anonymous class is returned. + + obj + A type, instance, or dictionary of callables. + cls + Optional, a type. All public methods of cls are considered the + interface. An ``obj`` instance of cls will always pass, ignoring + ``required``.. + methods + Optional, a sequence of method names to consider as the interface. + required + Optional, a sequence of mandatory implementations. If omitted, an + ``obj`` that provides at least one interface method is considered + sufficient. As a convenience, required may be a type, in which case + all public methods of the type are required. + + """ + if not cls and not methods: + raise TypeError("a class or collection of method names are required") + + if isinstance(cls, type) and isinstance(obj, cls): + return obj + + interface = set(methods or [m for m in dir(cls) if not m.startswith("_")]) + implemented = set(dir(obj)) + + complies = operator.ge + if isinstance(required, type): + required = interface + elif not required: + required = set() + complies = operator.gt + else: + required = set(required) + + if complies(implemented.intersection(interface), required): + return obj + + # No dict duck typing here. + if not isinstance(obj, dict): + qualifier = complies is operator.gt and "any of" or "all of" + raise TypeError( + "%r does not implement %s: %s" + % (obj, qualifier, ", ".join(interface)) + ) + + class AnonymousInterface(object): + """A callable-holding shell.""" + + if cls: + AnonymousInterface.__name__ = "Anonymous" + cls.__name__ + found = set() + + for method, impl in dictlike_iteritems(obj): + if method not in interface: + raise TypeError("%r: unknown in this interface" % method) + if not compat.callable(impl): + raise TypeError("%r=%r is not callable" % (method, impl)) + setattr(AnonymousInterface, method, staticmethod(impl)) + found.add(method) + + if complies(found, required): + return AnonymousInterface + + raise TypeError( + "dictionary does not contain required keys %s" + % ", ".join(required - found) + ) + + +class memoized_property(object): + """A read-only @property that is only evaluated once.""" + + def __init__(self, fget, doc=None): + self.fget = fget + self.__doc__ = doc or fget.__doc__ + self.__name__ = fget.__name__ + + def __get__(self, obj, cls): + if obj is None: + return self + obj.__dict__[self.__name__] = result = self.fget(obj) + return result + + def _reset(self, obj): + memoized_property.reset(obj, self.__name__) + + @classmethod + def reset(cls, obj, name): + obj.__dict__.pop(name, None) + + +def memoized_instancemethod(fn): + """Decorate a method memoize its return value. + + Best applied to no-arg methods: memoization is not sensitive to + argument values, and will always return the same value even when + called with different arguments. + + """ + + def oneshot(self, *args, **kw): + result = fn(self, *args, **kw) + + def memo(*a, **kw): + return result + + memo.__name__ = fn.__name__ + memo.__doc__ = fn.__doc__ + self.__dict__[fn.__name__] = memo + return result + + return update_wrapper(oneshot, fn) + + +class HasMemoized(object): + """A class that maintains the names of memoized elements in a + collection for easy cache clearing, generative, etc. + + """ + + __slots__ = () + + _memoized_keys = frozenset() + + def _reset_memoizations(self): + for elem in self._memoized_keys: + self.__dict__.pop(elem, None) + + def _assert_no_memoizations(self): + for elem in self._memoized_keys: + assert elem not in self.__dict__ + + def _set_memoized_attribute(self, key, value): + self.__dict__[key] = value + self._memoized_keys |= {key} + + class memoized_attribute(object): + """A read-only @property that is only evaluated once. + + :meta private: + + """ + + def __init__(self, fget, doc=None): + self.fget = fget + self.__doc__ = doc or fget.__doc__ + self.__name__ = fget.__name__ + + def __get__(self, obj, cls): + if obj is None: + return self + obj.__dict__[self.__name__] = result = self.fget(obj) + obj._memoized_keys |= {self.__name__} + return result + + @classmethod + def memoized_instancemethod(cls, fn): + """Decorate a method memoize its return value.""" + + def oneshot(self, *args, **kw): + result = fn(self, *args, **kw) + + def memo(*a, **kw): + return result + + memo.__name__ = fn.__name__ + memo.__doc__ = fn.__doc__ + self.__dict__[fn.__name__] = memo + self._memoized_keys |= {fn.__name__} + return result + + return update_wrapper(oneshot, fn) + + +class MemoizedSlots(object): + """Apply memoized items to an object using a __getattr__ scheme. + + This allows the functionality of memoized_property and + memoized_instancemethod to be available to a class using __slots__. + + """ + + __slots__ = () + + def _fallback_getattr(self, key): + raise AttributeError(key) + + def __getattr__(self, key): + if key.startswith("_memoized"): + raise AttributeError(key) + elif hasattr(self, "_memoized_attr_%s" % key): + value = getattr(self, "_memoized_attr_%s" % key)() + setattr(self, key, value) + return value + elif hasattr(self, "_memoized_method_%s" % key): + fn = getattr(self, "_memoized_method_%s" % key) + + def oneshot(*args, **kw): + result = fn(*args, **kw) + + def memo(*a, **kw): + return result + + memo.__name__ = fn.__name__ + memo.__doc__ = fn.__doc__ + setattr(self, key, memo) + return result + + oneshot.__doc__ = fn.__doc__ + return oneshot + else: + return self._fallback_getattr(key) + + +# from paste.deploy.converters +def asbool(obj): + if isinstance(obj, compat.string_types): + obj = obj.strip().lower() + if obj in ["true", "yes", "on", "y", "t", "1"]: + return True + elif obj in ["false", "no", "off", "n", "f", "0"]: + return False + else: + raise ValueError("String is not true/false: %r" % obj) + return bool(obj) + + +def bool_or_str(*text): + """Return a callable that will evaluate a string as + boolean, or one of a set of "alternate" string values. + + """ + + def bool_or_value(obj): + if obj in text: + return obj + else: + return asbool(obj) + + return bool_or_value + + +def asint(value): + """Coerce to integer.""" + + if value is None: + return value + return int(value) + + +def coerce_kw_type(kw, key, type_, flexi_bool=True, dest=None): + r"""If 'key' is present in dict 'kw', coerce its value to type 'type\_' if + necessary. If 'flexi_bool' is True, the string '0' is considered false + when coercing to boolean. + """ + + if dest is None: + dest = kw + + if ( + key in kw + and (not isinstance(type_, type) or not isinstance(kw[key], type_)) + and kw[key] is not None + ): + if type_ is bool and flexi_bool: + dest[key] = asbool(kw[key]) + else: + dest[key] = type_(kw[key]) + + +def constructor_key(obj, cls): + """Produce a tuple structure that is cacheable using the __dict__ of + obj to retrieve values + + """ + names = get_cls_kwargs(cls) + return (cls,) + tuple( + (k, obj.__dict__[k]) for k in names if k in obj.__dict__ + ) + + +def constructor_copy(obj, cls, *args, **kw): + """Instantiate cls using the __dict__ of obj as constructor arguments. + + Uses inspect to match the named arguments of ``cls``. + + """ + + names = get_cls_kwargs(cls) + kw.update( + (k, obj.__dict__[k]) for k in names.difference(kw) if k in obj.__dict__ + ) + return cls(*args, **kw) + + +def counter(): + """Return a threadsafe counter function.""" + + lock = compat.threading.Lock() + counter = itertools.count(1) + + # avoid the 2to3 "next" transformation... + def _next(): + with lock: + return next(counter) + + return _next + + +def duck_type_collection(specimen, default=None): + """Given an instance or class, guess if it is or is acting as one of + the basic collection types: list, set and dict. If the __emulates__ + property is present, return that preferentially. + """ + + if hasattr(specimen, "__emulates__"): + # canonicalize set vs sets.Set to a standard: the builtin set + if specimen.__emulates__ is not None and issubclass( + specimen.__emulates__, set + ): + return set + else: + return specimen.__emulates__ + + isa = isinstance(specimen, type) and issubclass or isinstance + if isa(specimen, list): + return list + elif isa(specimen, set): + return set + elif isa(specimen, dict): + return dict + + if hasattr(specimen, "append"): + return list + elif hasattr(specimen, "add"): + return set + elif hasattr(specimen, "set"): + return dict + else: + return default + + +def assert_arg_type(arg, argtype, name): + if isinstance(arg, argtype): + return arg + else: + if isinstance(argtype, tuple): + raise exc.ArgumentError( + "Argument '%s' is expected to be one of type %s, got '%s'" + % (name, " or ".join("'%s'" % a for a in argtype), type(arg)) + ) + else: + raise exc.ArgumentError( + "Argument '%s' is expected to be of type '%s', got '%s'" + % (name, argtype, type(arg)) + ) + + +def dictlike_iteritems(dictlike): + """Return a (key, value) iterator for almost any dict-like object.""" + + if compat.py3k: + if hasattr(dictlike, "items"): + return list(dictlike.items()) + else: + if hasattr(dictlike, "iteritems"): + return dictlike.iteritems() + elif hasattr(dictlike, "items"): + return iter(dictlike.items()) + + getter = getattr(dictlike, "__getitem__", getattr(dictlike, "get", None)) + if getter is None: + raise TypeError("Object '%r' is not dict-like" % dictlike) + + if hasattr(dictlike, "iterkeys"): + + def iterator(): + for key in dictlike.iterkeys(): + yield key, getter(key) + + return iterator() + elif hasattr(dictlike, "keys"): + return iter((key, getter(key)) for key in dictlike.keys()) + else: + raise TypeError("Object '%r' is not dict-like" % dictlike) + + +class classproperty(property): + """A decorator that behaves like @property except that operates + on classes rather than instances. + + The decorator is currently special when using the declarative + module, but note that the + :class:`~.sqlalchemy.ext.declarative.declared_attr` + decorator should be used for this purpose with declarative. + + """ + + def __init__(self, fget, *arg, **kw): + super(classproperty, self).__init__(fget, *arg, **kw) + self.__doc__ = fget.__doc__ + + def __get__(desc, self, cls): + return desc.fget(cls) + + +class hybridproperty(object): + def __init__(self, func): + self.func = func + self.clslevel = func + + def __get__(self, instance, owner): + if instance is None: + clsval = self.clslevel(owner) + return clsval + else: + return self.func(instance) + + def classlevel(self, func): + self.clslevel = func + return self + + +class hybridmethod(object): + """Decorate a function as cls- or instance- level.""" + + def __init__(self, func): + self.func = self.__func__ = func + self.clslevel = func + + def __get__(self, instance, owner): + if instance is None: + return self.clslevel.__get__(owner, owner.__class__) + else: + return self.func.__get__(instance, owner) + + def classlevel(self, func): + self.clslevel = func + return self + + +class _symbol(int): + def __new__(self, name, doc=None, canonical=None): + """Construct a new named symbol.""" + assert isinstance(name, compat.string_types) + if canonical is None: + canonical = hash(name) + v = int.__new__(_symbol, canonical) + v.name = name + if doc: + v.__doc__ = doc + return v + + def __reduce__(self): + return symbol, (self.name, "x", int(self)) + + def __str__(self): + return repr(self) + + def __repr__(self): + return "symbol(%r)" % self.name + + +_symbol.__name__ = "symbol" + + +class symbol(object): + """A constant symbol. + + >>> symbol('foo') is symbol('foo') + True + >>> symbol('foo') + <symbol 'foo> + + A slight refinement of the MAGICCOOKIE=object() pattern. The primary + advantage of symbol() is its repr(). They are also singletons. + + Repeated calls of symbol('name') will all return the same instance. + + The optional ``doc`` argument assigns to ``__doc__``. This + is strictly so that Sphinx autoattr picks up the docstring we want + (it doesn't appear to pick up the in-module docstring if the datamember + is in a different module - autoattribute also blows up completely). + If Sphinx fixes/improves this then we would no longer need + ``doc`` here. + + """ + + symbols = {} + _lock = compat.threading.Lock() + + def __new__(cls, name, doc=None, canonical=None): + with cls._lock: + sym = cls.symbols.get(name) + if sym is None: + cls.symbols[name] = sym = _symbol(name, doc, canonical) + return sym + + @classmethod + def parse_user_argument( + cls, arg, choices, name, resolve_symbol_names=False + ): + """Given a user parameter, parse the parameter into a chosen symbol. + + The user argument can be a string name that matches the name of a + symbol, or the symbol object itself, or any number of alternate choices + such as True/False/ None etc. + + :param arg: the user argument. + :param choices: dictionary of symbol object to list of possible + entries. + :param name: name of the argument. Used in an :class:`.ArgumentError` + that is raised if the parameter doesn't match any available argument. + :param resolve_symbol_names: include the name of each symbol as a valid + entry. + + """ + # note using hash lookup is tricky here because symbol's `__hash__` + # is its int value which we don't want included in the lookup + # explicitly, so we iterate and compare each. + for sym, choice in choices.items(): + if arg is sym: + return sym + elif resolve_symbol_names and arg == sym.name: + return sym + elif arg in choice: + return sym + + if arg is None: + return None + + raise exc.ArgumentError("Invalid value for '%s': %r" % (name, arg)) + + +_creation_order = 1 + + +def set_creation_order(instance): + """Assign a '_creation_order' sequence to the given instance. + + This allows multiple instances to be sorted in order of creation + (typically within a single thread; the counter is not particularly + threadsafe). + + """ + global _creation_order + instance._creation_order = _creation_order + _creation_order += 1 + + +def warn_exception(func, *args, **kwargs): + """executes the given function, catches all exceptions and converts to + a warning. + + """ + try: + return func(*args, **kwargs) + except Exception: + warn("%s('%s') ignored" % sys.exc_info()[0:2]) + + +def ellipses_string(value, len_=25): + try: + if len(value) > len_: + return "%s..." % value[0:len_] + else: + return value + except TypeError: + return value + + +class _hash_limit_string(compat.text_type): + """A string subclass that can only be hashed on a maximum amount + of unique values. + + This is used for warnings so that we can send out parameterized warnings + without the __warningregistry__ of the module, or the non-overridable + "once" registry within warnings.py, overloading memory, + + + """ + + def __new__(cls, value, num, args): + interpolated = (value % args) + ( + " (this warning may be suppressed after %d occurrences)" % num + ) + self = super(_hash_limit_string, cls).__new__(cls, interpolated) + self._hash = hash("%s_%d" % (value, hash(interpolated) % num)) + return self + + def __hash__(self): + return self._hash + + def __eq__(self, other): + return hash(self) == hash(other) + + +def warn(msg, code=None): + """Issue a warning. + + If msg is a string, :class:`.exc.SAWarning` is used as + the category. + + """ + if code: + _warnings_warn(exc.SAWarning(msg, code=code)) + else: + _warnings_warn(msg, exc.SAWarning) + + +def warn_limited(msg, args): + """Issue a warning with a parameterized string, limiting the number + of registrations. + + """ + if args: + msg = _hash_limit_string(msg, 10, args) + _warnings_warn(msg, exc.SAWarning) + + +def _warnings_warn(message, category=None, stacklevel=2): + + # adjust the given stacklevel to be outside of SQLAlchemy + try: + frame = sys._getframe(stacklevel) + except ValueError: + # being called from less than 3 (or given) stacklevels, weird, + # but don't crash + stacklevel = 0 + except: + # _getframe() doesn't work, weird interpreter issue, weird, + # ok, but don't crash + stacklevel = 0 + else: + # using __name__ here requires that we have __name__ in the + # __globals__ of the decorated string functions we make also. + # we generate this using {"__name__": fn.__module__} + while frame is not None and re.match( + r"^(?:sqlalchemy\.|alembic\.)", frame.f_globals.get("__name__", "") + ): + frame = frame.f_back + stacklevel += 1 + + if category is not None: + warnings.warn(message, category, stacklevel=stacklevel + 1) + else: + warnings.warn(message, stacklevel=stacklevel + 1) + + +def only_once(fn, retry_on_exception): + """Decorate the given function to be a no-op after it is called exactly + once.""" + + once = [fn] + + def go(*arg, **kw): + # strong reference fn so that it isn't garbage collected, + # which interferes with the event system's expectations + strong_fn = fn # noqa + if once: + once_fn = once.pop() + try: + return once_fn(*arg, **kw) + except: + if retry_on_exception: + once.insert(0, once_fn) + raise + + return go + + +_SQLA_RE = re.compile(r"sqlalchemy/([a-z_]+/){0,2}[a-z_]+\.py") +_UNITTEST_RE = re.compile(r"unit(?:2|test2?/)") + + +def chop_traceback(tb, exclude_prefix=_UNITTEST_RE, exclude_suffix=_SQLA_RE): + """Chop extraneous lines off beginning and end of a traceback. + + :param tb: + a list of traceback lines as returned by ``traceback.format_stack()`` + + :param exclude_prefix: + a regular expression object matching lines to skip at beginning of + ``tb`` + + :param exclude_suffix: + a regular expression object matching lines to skip at end of ``tb`` + """ + start = 0 + end = len(tb) - 1 + while start <= end and exclude_prefix.search(tb[start]): + start += 1 + while start <= end and exclude_suffix.search(tb[end]): + end -= 1 + return tb[start : end + 1] + + +NoneType = type(None) + + +def attrsetter(attrname): + code = "def set(obj, value):" " obj.%s = value" % attrname + env = locals().copy() + exec(code, env) + return env["set"] + + +class EnsureKWArgType(type): + r"""Apply translation of functions to accept \**kw arguments if they + don't already. + + """ + + def __init__(cls, clsname, bases, clsdict): + fn_reg = cls.ensure_kwarg + if fn_reg: + for key in clsdict: + m = re.match(fn_reg, key) + if m: + fn = clsdict[key] + spec = compat.inspect_getfullargspec(fn) + if not spec.varkw: + clsdict[key] = wrapped = cls._wrap_w_kw(fn) + setattr(cls, key, wrapped) + super(EnsureKWArgType, cls).__init__(clsname, bases, clsdict) + + def _wrap_w_kw(self, fn): + def wrap(*arg, **kw): + return fn(*arg) + + return update_wrapper(wrap, fn) + + +def wrap_callable(wrapper, fn): + """Augment functools.update_wrapper() to work with objects with + a ``__call__()`` method. + + :param fn: + object with __call__ method + + """ + if hasattr(fn, "__name__"): + return update_wrapper(wrapper, fn) + else: + _f = wrapper + _f.__name__ = fn.__class__.__name__ + if hasattr(fn, "__module__"): + _f.__module__ = fn.__module__ + + if hasattr(fn.__call__, "__doc__") and fn.__call__.__doc__: + _f.__doc__ = fn.__call__.__doc__ + elif fn.__doc__: + _f.__doc__ = fn.__doc__ + + return _f + + +def quoted_token_parser(value): + """Parse a dotted identifier with accommodation for quoted names. + + Includes support for SQL-style double quotes as a literal character. + + E.g.:: + + >>> quoted_token_parser("name") + ["name"] + >>> quoted_token_parser("schema.name") + ["schema", "name"] + >>> quoted_token_parser('"Schema"."Name"') + ['Schema', 'Name'] + >>> quoted_token_parser('"Schema"."Name""Foo"') + ['Schema', 'Name""Foo'] + + """ + + if '"' not in value: + return value.split(".") + + # 0 = outside of quotes + # 1 = inside of quotes + state = 0 + result = [[]] + idx = 0 + lv = len(value) + while idx < lv: + char = value[idx] + if char == '"': + if state == 1 and idx < lv - 1 and value[idx + 1] == '"': + result[-1].append('"') + idx += 1 + else: + state ^= 1 + elif char == "." and state == 0: + result.append([]) + else: + result[-1].append(char) + idx += 1 + + return ["".join(token) for token in result] + + +def add_parameter_text(params, text): + params = _collections.to_list(params) + + def decorate(fn): + doc = fn.__doc__ is not None and fn.__doc__ or "" + if doc: + doc = inject_param_text(doc, {param: text for param in params}) + fn.__doc__ = doc + return fn + + return decorate + + +def _dedent_docstring(text): + split_text = text.split("\n", 1) + if len(split_text) == 1: + return text + else: + firstline, remaining = split_text + if not firstline.startswith(" "): + return firstline + "\n" + textwrap.dedent(remaining) + else: + return textwrap.dedent(text) + + +def inject_docstring_text(doctext, injecttext, pos): + doctext = _dedent_docstring(doctext or "") + lines = doctext.split("\n") + if len(lines) == 1: + lines.append("") + injectlines = textwrap.dedent(injecttext).split("\n") + if injectlines[0]: + injectlines.insert(0, "") + + blanks = [num for num, line in enumerate(lines) if not line.strip()] + blanks.insert(0, 0) + + inject_pos = blanks[min(pos, len(blanks) - 1)] + + lines = lines[0:inject_pos] + injectlines + lines[inject_pos:] + return "\n".join(lines) + + +_param_reg = re.compile(r"(\s+):param (.+?):") + + +def inject_param_text(doctext, inject_params): + doclines = collections.deque(doctext.splitlines()) + lines = [] + + # TODO: this is not working for params like ":param case_sensitive=True:" + + to_inject = None + while doclines: + line = doclines.popleft() + + m = _param_reg.match(line) + + if to_inject is None: + if m: + param = m.group(2).lstrip("*") + if param in inject_params: + # default indent to that of :param: plus one + indent = " " * len(m.group(1)) + " " + + # but if the next line has text, use that line's + # indentation + if doclines: + m2 = re.match(r"(\s+)\S", doclines[0]) + if m2: + indent = " " * len(m2.group(1)) + + to_inject = indent + inject_params[param] + elif m: + lines.extend(["\n", to_inject, "\n"]) + to_inject = None + elif not line.rstrip(): + lines.extend([line, to_inject, "\n"]) + to_inject = None + elif line.endswith("::"): + # TODO: this still wont cover if the code example itself has blank + # lines in it, need to detect those via indentation. + lines.extend([line, doclines.popleft()]) + continue + lines.append(line) + + return "\n".join(lines) + + +def repr_tuple_names(names): + """Trims a list of strings from the middle and return a string of up to + four elements. Strings greater than 11 characters will be truncated""" + if len(names) == 0: + return None + flag = len(names) <= 4 + names = names[0:4] if flag else names[0:3] + names[-1:] + res = ["%s.." % name[:11] if len(name) > 11 else name for name in names] + if flag: + return ", ".join(res) + else: + return "%s, ..., %s" % (", ".join(res[0:3]), res[-1]) + + +def has_compiled_ext(): + try: + from sqlalchemy import cimmutabledict # noqa: F401 + from sqlalchemy import cprocessors # noqa: F401 + from sqlalchemy import cresultproxy # noqa: F401 + + return True + except ImportError: + return False diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py new file mode 100644 index 0000000..67c5219 --- /dev/null +++ b/lib/sqlalchemy/util/queue.py @@ -0,0 +1,291 @@ +# util/queue.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""An adaptation of Py2.3/2.4's Queue module which supports reentrant +behavior, using RLock instead of Lock for its mutex object. The +Queue object is used exclusively by the sqlalchemy.pool.QueuePool +class. + +This is to support the connection pool's usage of weakref callbacks to return +connections to the underlying Queue, which can in extremely +rare cases be invoked within the ``get()`` method of the Queue itself, +producing a ``put()`` inside the ``get()`` and therefore a reentrant +condition. + +""" + +from collections import deque +from time import time as _time + +from . import compat +from .compat import threading +from .concurrency import asyncio +from .concurrency import await_fallback +from .concurrency import await_only +from .langhelpers import memoized_property + + +__all__ = ["Empty", "Full", "Queue"] + + +class Empty(Exception): + "Exception raised by Queue.get(block=0)/get_nowait()." + + pass + + +class Full(Exception): + "Exception raised by Queue.put(block=0)/put_nowait()." + + pass + + +class Queue: + def __init__(self, maxsize=0, use_lifo=False): + """Initialize a queue object with a given maximum size. + + If `maxsize` is <= 0, the queue size is infinite. + + If `use_lifo` is True, this Queue acts like a Stack (LIFO). + """ + + self._init(maxsize) + # mutex must be held whenever the queue is mutating. All methods + # that acquire mutex must release it before returning. mutex + # is shared between the two conditions, so acquiring and + # releasing the conditions also acquires and releases mutex. + self.mutex = threading.RLock() + # Notify not_empty whenever an item is added to the queue; a + # thread waiting to get is notified then. + self.not_empty = threading.Condition(self.mutex) + # Notify not_full whenever an item is removed from the queue; + # a thread waiting to put is notified then. + self.not_full = threading.Condition(self.mutex) + # If this queue uses LIFO or FIFO + self.use_lifo = use_lifo + + def qsize(self): + """Return the approximate size of the queue (not reliable!).""" + + with self.mutex: + return self._qsize() + + def empty(self): + """Return True if the queue is empty, False otherwise (not + reliable!).""" + + with self.mutex: + return self._empty() + + def full(self): + """Return True if the queue is full, False otherwise (not + reliable!).""" + + with self.mutex: + return self._full() + + def put(self, item, block=True, timeout=None): + """Put an item into the queue. + + If optional args `block` is True and `timeout` is None (the + default), block if necessary until a free slot is + available. If `timeout` is a positive number, it blocks at + most `timeout` seconds and raises the ``Full`` exception if no + free slot was available within that time. Otherwise (`block` + is false), put an item on the queue if a free slot is + immediately available, else raise the ``Full`` exception + (`timeout` is ignored in that case). + """ + + with self.not_full: + if not block: + if self._full(): + raise Full + elif timeout is None: + while self._full(): + self.not_full.wait() + else: + if timeout < 0: + raise ValueError("'timeout' must be a positive number") + endtime = _time() + timeout + while self._full(): + remaining = endtime - _time() + if remaining <= 0.0: + raise Full + self.not_full.wait(remaining) + self._put(item) + self.not_empty.notify() + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + Only enqueue the item if a free slot is immediately available. + Otherwise raise the ``Full`` exception. + """ + return self.put(item, False) + + def get(self, block=True, timeout=None): + """Remove and return an item from the queue. + + If optional args `block` is True and `timeout` is None (the + default), block if necessary until an item is available. If + `timeout` is a positive number, it blocks at most `timeout` + seconds and raises the ``Empty`` exception if no item was + available within that time. Otherwise (`block` is false), + return an item if one is immediately available, else raise the + ``Empty`` exception (`timeout` is ignored in that case). + + """ + with self.not_empty: + if not block: + if self._empty(): + raise Empty + elif timeout is None: + while self._empty(): + self.not_empty.wait() + else: + if timeout < 0: + raise ValueError("'timeout' must be a positive number") + endtime = _time() + timeout + while self._empty(): + remaining = endtime - _time() + if remaining <= 0.0: + raise Empty + self.not_empty.wait(remaining) + item = self._get() + self.not_full.notify() + return item + + def get_nowait(self): + """Remove and return an item from the queue without blocking. + + Only get an item if one is immediately available. Otherwise + raise the ``Empty`` exception. + """ + + return self.get(False) + + # Override these methods to implement other queue organizations + # (e.g. stack or priority queue). + # These will only be called with appropriate locks held + + # Initialize the queue representation + def _init(self, maxsize): + self.maxsize = maxsize + self.queue = deque() + + def _qsize(self): + return len(self.queue) + + # Check whether the queue is empty + def _empty(self): + return not self.queue + + # Check whether the queue is full + def _full(self): + return self.maxsize > 0 and len(self.queue) == self.maxsize + + # Put a new item in the queue + def _put(self, item): + self.queue.append(item) + + # Get an item from the queue + def _get(self): + if self.use_lifo: + # LIFO + return self.queue.pop() + else: + # FIFO + return self.queue.popleft() + + +class AsyncAdaptedQueue: + await_ = staticmethod(await_only) + + def __init__(self, maxsize=0, use_lifo=False): + self.use_lifo = use_lifo + self.maxsize = maxsize + + def empty(self): + return self._queue.empty() + + def full(self): + return self._queue.full() + + def qsize(self): + return self._queue.qsize() + + @memoized_property + def _queue(self): + # Delay creation of the queue until it is first used, to avoid + # binding it to a possibly wrong event loop. + # By delaying the creation of the pool we accommodate the common + # usage pattern of instantiating the engine at module level, where a + # different event loop is in present compared to when the application + # is actually run. + + if self.use_lifo: + queue = asyncio.LifoQueue(maxsize=self.maxsize) + else: + queue = asyncio.Queue(maxsize=self.maxsize) + return queue + + def put_nowait(self, item): + try: + return self._queue.put_nowait(item) + except asyncio.QueueFull as err: + compat.raise_( + Full(), + replace_context=err, + ) + + def put(self, item, block=True, timeout=None): + if not block: + return self.put_nowait(item) + + try: + if timeout is not None: + return self.await_( + asyncio.wait_for(self._queue.put(item), timeout) + ) + else: + return self.await_(self._queue.put(item)) + except (asyncio.QueueFull, asyncio.TimeoutError) as err: + compat.raise_( + Full(), + replace_context=err, + ) + + def get_nowait(self): + try: + return self._queue.get_nowait() + except asyncio.QueueEmpty as err: + compat.raise_( + Empty(), + replace_context=err, + ) + + def get(self, block=True, timeout=None): + if not block: + return self.get_nowait() + + try: + if timeout is not None: + return self.await_( + asyncio.wait_for(self._queue.get(), timeout) + ) + else: + return self.await_(self._queue.get()) + except (asyncio.QueueEmpty, asyncio.TimeoutError) as err: + compat.raise_( + Empty(), + replace_context=err, + ) + + +class FallbackAsyncAdaptedQueue(AsyncAdaptedQueue): + await_ = staticmethod(await_fallback) diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py new file mode 100644 index 0000000..bbc819f --- /dev/null +++ b/lib/sqlalchemy/util/topological.py @@ -0,0 +1,100 @@ +# util/topological.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Topological sorting algorithms.""" + +from .. import util +from ..exc import CircularDependencyError + +__all__ = ["sort", "sort_as_subsets", "find_cycles"] + + +def sort_as_subsets(tuples, allitems): + + edges = util.defaultdict(set) + for parent, child in tuples: + edges[child].add(parent) + + todo = list(allitems) + todo_set = set(allitems) + + while todo_set: + output = [] + for node in todo: + if todo_set.isdisjoint(edges[node]): + output.append(node) + + if not output: + raise CircularDependencyError( + "Circular dependency detected.", + find_cycles(tuples, allitems), + _gen_edges(edges), + ) + + todo_set.difference_update(output) + todo = [t for t in todo if t in todo_set] + yield output + + +def sort(tuples, allitems, deterministic_order=True): + """sort the given list of items by dependency. + + 'tuples' is a list of tuples representing a partial ordering. + + deterministic_order is no longer used, the order is now always + deterministic given the order of "allitems". the flag is there + for backwards compatibility with Alembic. + + """ + + for set_ in sort_as_subsets(tuples, allitems): + for s in set_: + yield s + + +def find_cycles(tuples, allitems): + # adapted from: + # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html + + edges = util.defaultdict(set) + for parent, child in tuples: + edges[parent].add(child) + nodes_to_test = set(edges) + + output = set() + + # we'd like to find all nodes that are + # involved in cycles, so we do the full + # pass through the whole thing for each + # node in the original list. + + # we can go just through parent edge nodes. + # if a node is only a child and never a parent, + # by definition it can't be part of a cycle. same + # if it's not in the edges at all. + for node in nodes_to_test: + stack = [node] + todo = nodes_to_test.difference(stack) + while stack: + top = stack[-1] + for node in edges[top]: + if node in stack: + cyc = stack[stack.index(node) :] + todo.difference_update(cyc) + output.update(cyc) + + if node in todo: + stack.append(node) + todo.remove(node) + break + else: + node = stack.pop() + return output + + +def _gen_edges(edges): + return set([(right, left) for left in edges for right in edges[left]]) |