diff options
Diffstat (limited to 'lib/sqlalchemy/engine/util.py')
-rw-r--r-- | lib/sqlalchemy/engine/util.py | 253 |
1 files changed, 253 insertions, 0 deletions
diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py new file mode 100644 index 0000000..1b03ebb --- /dev/null +++ b/lib/sqlalchemy/engine/util.py @@ -0,0 +1,253 @@ +# engine/util.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .. import exc +from .. import util +from ..util import collections_abc +from ..util import immutabledict + + +def connection_memoize(key): + """Decorator, memoize a function in a connection.info stash. + + Only applicable to functions which take no arguments other than a + connection. The memo will be stored in ``connection.info[key]``. + """ + + @util.decorator + def decorated(fn, self, connection): + connection = connection.connect() + try: + return connection.info[key] + except KeyError: + connection.info[key] = val = fn(self, connection) + return val + + return decorated + + +_no_tuple = () +_no_kw = util.immutabledict() + + +def _distill_params(connection, multiparams, params): + r"""Given arguments from the calling form \*multiparams, \**params, + return a list of bind parameter structures, usually a list of + dictionaries. + + In the case of 'raw' execution which accepts positional parameters, + it may be a list of tuples or lists. + + """ + + if not multiparams: + if params: + connection._warn_for_legacy_exec_format() + return [params] + else: + return [] + elif len(multiparams) == 1: + zero = multiparams[0] + if isinstance(zero, (list, tuple)): + if ( + not zero + or hasattr(zero[0], "__iter__") + and not hasattr(zero[0], "strip") + ): + # execute(stmt, [{}, {}, {}, ...]) + # execute(stmt, [(), (), (), ...]) + return zero + else: + # this is used by exec_driver_sql only, so a deprecation + # warning would already be coming from passing a plain + # textual statement with positional parameters to + # execute(). + # execute(stmt, ("value", "value")) + return [zero] + elif hasattr(zero, "keys"): + # execute(stmt, {"key":"value"}) + return [zero] + else: + connection._warn_for_legacy_exec_format() + # execute(stmt, "value") + return [[zero]] + else: + connection._warn_for_legacy_exec_format() + if hasattr(multiparams[0], "__iter__") and not hasattr( + multiparams[0], "strip" + ): + return multiparams + else: + return [multiparams] + + +def _distill_cursor_params(connection, multiparams, params): + """_distill_params without any warnings. more appropriate for + "cursor" params that can include tuple arguments, lists of tuples, + etc. + + """ + + if not multiparams: + if params: + return [params] + else: + return [] + elif len(multiparams) == 1: + zero = multiparams[0] + if isinstance(zero, (list, tuple)): + if ( + not zero + or hasattr(zero[0], "__iter__") + and not hasattr(zero[0], "strip") + ): + # execute(stmt, [{}, {}, {}, ...]) + # execute(stmt, [(), (), (), ...]) + return zero + else: + # this is used by exec_driver_sql only, so a deprecation + # warning would already be coming from passing a plain + # textual statement with positional parameters to + # execute(). + # execute(stmt, ("value", "value")) + + return [zero] + elif hasattr(zero, "keys"): + # execute(stmt, {"key":"value"}) + return [zero] + else: + # execute(stmt, "value") + return [[zero]] + else: + if hasattr(multiparams[0], "__iter__") and not hasattr( + multiparams[0], "strip" + ): + return multiparams + else: + return [multiparams] + + +def _distill_params_20(params): + if params is None: + return _no_tuple, _no_kw + elif isinstance(params, list): + # collections_abc.MutableSequence): # avoid abc.__instancecheck__ + if params and not isinstance( + params[0], (collections_abc.Mapping, tuple) + ): + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + + return (params,), _no_kw + elif isinstance( + params, + (tuple, dict, immutabledict), + # only do abc.__instancecheck__ for Mapping after we've checked + # for plain dictionaries and would otherwise raise + ) or isinstance(params, collections_abc.Mapping): + return (params,), _no_kw + else: + raise exc.ArgumentError("mapping or sequence expected for parameters") + + +class TransactionalContext(object): + """Apply Python context manager behavior to transaction objects. + + Performs validation to ensure the subject of the transaction is not + used if the transaction were ended prematurely. + + """ + + _trans_subject = None + + def _transaction_is_active(self): + raise NotImplementedError() + + def _transaction_is_closed(self): + raise NotImplementedError() + + def _rollback_can_be_called(self): + """indicates the object is in a state that is known to be acceptable + for rollback() to be called. + + This does not necessarily mean rollback() will succeed or not raise + an error, just that there is currently no state detected that indicates + rollback() would fail or emit warnings. + + It also does not mean that there's a transaction in progress, as + it is usually safe to call rollback() even if no transaction is + present. + + .. versionadded:: 1.4.28 + + """ + raise NotImplementedError() + + def _get_subject(self): + raise NotImplementedError() + + @classmethod + def _trans_ctx_check(cls, subject): + trans_context = subject._trans_context_manager + if trans_context: + if not trans_context._transaction_is_active(): + raise exc.InvalidRequestError( + "Can't operate on closed transaction inside context " + "manager. Please complete the context manager " + "before emitting further commands." + ) + + def __enter__(self): + subject = self._get_subject() + + # none for outer transaction, may be non-None for nested + # savepoint, legacy nesting cases + trans_context = subject._trans_context_manager + self._outer_trans_ctx = trans_context + + self._trans_subject = subject + subject._trans_context_manager = self + return self + + def __exit__(self, type_, value, traceback): + subject = self._trans_subject + + # simplistically we could assume that + # "subject._trans_context_manager is self". However, any calling + # code that is manipulating __exit__ directly would break this + # assumption. alembic context manager + # is an example of partial use that just calls __exit__ and + # not __enter__ at the moment. it's safe to assume this is being done + # in the wild also + out_of_band_exit = ( + subject is None or subject._trans_context_manager is not self + ) + + if type_ is None and self._transaction_is_active(): + try: + self.commit() + except: + with util.safe_reraise(): + if self._rollback_can_be_called(): + self.rollback() + finally: + if not out_of_band_exit: + subject._trans_context_manager = self._outer_trans_ctx + self._trans_subject = self._outer_trans_ctx = None + else: + try: + if not self._transaction_is_active(): + if not self._transaction_is_closed(): + self.close() + else: + if self._rollback_can_be_called(): + self.rollback() + finally: + if not out_of_band_exit: + subject._trans_context_manager = self._outer_trans_ctx + self._trans_subject = self._outer_trans_ctx = None |