diff options
author | xiubuzhe <xiubuzhe@sina.com> | 2023-10-08 20:59:00 +0800 |
---|---|---|
committer | xiubuzhe <xiubuzhe@sina.com> | 2023-10-08 20:59:00 +0800 |
commit | 1dac2263372df2b85db5d029a45721fa158a5c9d (patch) | |
tree | 0365f9c57df04178a726d7584ca6a6b955a7ce6a /lib/sqlalchemy/testing/assertions.py | |
parent | b494be364bb39e1de128ada7dc576a729d99907e (diff) | |
download | sunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.tar.gz sunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.tar.bz2 sunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.zip |
first add files
Diffstat (limited to 'lib/sqlalchemy/testing/assertions.py')
-rw-r--r-- | lib/sqlalchemy/testing/assertions.py | 845 |
1 files changed, 845 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py new file mode 100644 index 0000000..9a3c06b --- /dev/null +++ b/lib/sqlalchemy/testing/assertions.py @@ -0,0 +1,845 @@ +# testing/assertions.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import absolute_import + +import contextlib +import re +import sys +import warnings + +from . import assertsql +from . import config +from . import engines +from . import mock +from .exclusions import db_spec +from .util import fail +from .. import exc as sa_exc +from .. import schema +from .. import sql +from .. import types as sqltypes +from .. import util +from ..engine import default +from ..engine import url +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util import compat +from ..util import decorator + + +def expect_warnings(*messages, **kw): + """Context manager which expects one or more warnings. + + With no arguments, squelches all SAWarning and RemovedIn20Warning emitted via + sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise + pass string expressions that will match selected warnings via regex; + all non-matching warnings are sent through. + + The expect version **asserts** that the warnings were in fact seen. + + Note that the test suite sets SAWarning warnings to raise exceptions. + + """ # noqa + return _expect_warnings( + (sa_exc.RemovedIn20Warning, sa_exc.SAWarning), messages, **kw + ) + + +@contextlib.contextmanager +def expect_warnings_on(db, *messages, **kw): + """Context manager which expects one or more warnings on specific + dialects. + + The expect version **asserts** that the warnings were in fact seen. + + """ + spec = db_spec(db) + + if isinstance(db, util.string_types) and not spec(config._current): + yield + else: + with expect_warnings(*messages, **kw): + yield + + +def emits_warning(*messages): + """Decorator form of expect_warnings(). + + Note that emits_warning does **not** assert that the warnings + were in fact seen. + + """ + + @decorator + def decorate(fn, *args, **kw): + with expect_warnings(assert_=False, *messages): + return fn(*args, **kw) + + return decorate + + +def expect_deprecated(*messages, **kw): + return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw) + + +def expect_deprecated_20(*messages, **kw): + return _expect_warnings(sa_exc.Base20DeprecationWarning, messages, **kw) + + +def emits_warning_on(db, *messages): + """Mark a test as emitting a warning on a specific dialect. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + + Note that emits_warning_on does **not** assert that the warnings + were in fact seen. + + """ + + @decorator + def decorate(fn, *args, **kw): + with expect_warnings_on(db, assert_=False, *messages): + return fn(*args, **kw) + + return decorate + + +def uses_deprecated(*messages): + """Mark a test as immune from fatal deprecation warnings. + + With no arguments, squelches all SADeprecationWarning failures. + Or pass one or more strings; these will be matched to the root + of the warning description by warnings.filterwarnings(). + + As a special case, you may pass a function name prefixed with // + and it will be re-written as needed to match the standard warning + verbiage emitted by the sqlalchemy.util.deprecated decorator. + + Note that uses_deprecated does **not** assert that the warnings + were in fact seen. + + """ + + @decorator + def decorate(fn, *args, **kw): + with expect_deprecated(*messages, assert_=False): + return fn(*args, **kw) + + return decorate + + +_FILTERS = None +_SEEN = None +_EXC_CLS = None + + +@contextlib.contextmanager +def _expect_warnings( + exc_cls, + messages, + regex=True, + search_msg=False, + assert_=True, + py2konly=False, + raise_on_any_unexpected=False, + squelch_other_warnings=False, +): + + global _FILTERS, _SEEN, _EXC_CLS + + if regex or search_msg: + filters = [re.compile(msg, re.I | re.S) for msg in messages] + else: + filters = list(messages) + + if _FILTERS is not None: + # nested call; update _FILTERS and _SEEN, return. outer + # block will assert our messages + assert _SEEN is not None + assert _EXC_CLS is not None + _FILTERS.extend(filters) + _SEEN.update(filters) + _EXC_CLS += (exc_cls,) + yield + else: + seen = _SEEN = set(filters) + _FILTERS = filters + _EXC_CLS = (exc_cls,) + + if raise_on_any_unexpected: + + def real_warn(msg, *arg, **kw): + raise AssertionError("Got unexpected warning: %r" % msg) + + else: + real_warn = warnings.warn + + def our_warn(msg, *arg, **kw): + + if isinstance(msg, _EXC_CLS): + exception = type(msg) + msg = str(msg) + elif arg: + exception = arg[0] + else: + exception = None + + if not exception or not issubclass(exception, _EXC_CLS): + if not squelch_other_warnings: + return real_warn(msg, *arg, **kw) + else: + return + + if not filters and not raise_on_any_unexpected: + return + + for filter_ in filters: + if ( + (search_msg and filter_.search(msg)) + or (regex and filter_.match(msg)) + or (not regex and filter_ == msg) + ): + seen.discard(filter_) + break + else: + if not squelch_other_warnings: + real_warn(msg, *arg, **kw) + + with mock.patch("warnings.warn", our_warn), mock.patch( + "sqlalchemy.util.SQLALCHEMY_WARN_20", True + ), mock.patch( + "sqlalchemy.util.deprecations.SQLALCHEMY_WARN_20", True + ), mock.patch( + "sqlalchemy.engine.row.LegacyRow._default_key_style", 2 + ): + try: + yield + finally: + _SEEN = _FILTERS = _EXC_CLS = None + + if assert_ and (not py2konly or not compat.py3k): + assert not seen, "Warnings were not seen: %s" % ", ".join( + "%r" % (s.pattern if regex else s) for s in seen + ) + + +def global_cleanup_assertions(): + """Check things that have to be finalized at the end of a test suite. + + Hardcoded at the moment, a modular system can be built here + to support things like PG prepared transactions, tables all + dropped, etc. + + """ + _assert_no_stray_pool_connections() + + +def _assert_no_stray_pool_connections(): + engines.testing_reaper.assert_all_closed() + + +def eq_regex(a, b, msg=None): + assert re.match(b, a), msg or "%r !~ %r" % (a, b) + + +def eq_(a, b, msg=None): + """Assert a == b, with repr messaging on failure.""" + assert a == b, msg or "%r != %r" % (a, b) + + +def ne_(a, b, msg=None): + """Assert a != b, with repr messaging on failure.""" + assert a != b, msg or "%r == %r" % (a, b) + + +def le_(a, b, msg=None): + """Assert a <= b, with repr messaging on failure.""" + assert a <= b, msg or "%r != %r" % (a, b) + + +def is_instance_of(a, b, msg=None): + assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b) + + +def is_none(a, msg=None): + is_(a, None, msg=msg) + + +def is_not_none(a, msg=None): + is_not(a, None, msg=msg) + + +def is_true(a, msg=None): + is_(bool(a), True, msg=msg) + + +def is_false(a, msg=None): + is_(bool(a), False, msg=msg) + + +def is_(a, b, msg=None): + """Assert a is b, with repr messaging on failure.""" + assert a is b, msg or "%r is not %r" % (a, b) + + +def is_not(a, b, msg=None): + """Assert a is not b, with repr messaging on failure.""" + assert a is not b, msg or "%r is %r" % (a, b) + + +# deprecated. See #5429 +is_not_ = is_not + + +def in_(a, b, msg=None): + """Assert a in b, with repr messaging on failure.""" + assert a in b, msg or "%r not in %r" % (a, b) + + +def not_in(a, b, msg=None): + """Assert a in not b, with repr messaging on failure.""" + assert a not in b, msg or "%r is in %r" % (a, b) + + +# deprecated. See #5429 +not_in_ = not_in + + +def startswith_(a, fragment, msg=None): + """Assert a.startswith(fragment), with repr messaging on failure.""" + assert a.startswith(fragment), msg or "%r does not start with %r" % ( + a, + fragment, + ) + + +def eq_ignore_whitespace(a, b, msg=None): + a = re.sub(r"^\s+?|\n", "", a) + a = re.sub(r" {2,}", " ", a) + b = re.sub(r"^\s+?|\n", "", b) + b = re.sub(r" {2,}", " ", b) + + assert a == b, msg or "%r != %r" % (a, b) + + +def _assert_proper_exception_context(exception): + """assert that any exception we're catching does not have a __context__ + without a __cause__, and that __suppress_context__ is never set. + + Python 3 will report nested as exceptions as "during the handling of + error X, error Y occurred". That's not what we want to do. we want + these exceptions in a cause chain. + + """ + + if not util.py3k: + return + + if ( + exception.__context__ is not exception.__cause__ + and not exception.__suppress_context__ + ): + assert False, ( + "Exception %r was correctly raised but did not set a cause, " + "within context %r as its cause." + % (exception, exception.__context__) + ) + + +def assert_raises(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw, check_context=True) + + +def assert_raises_context_ok(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw) + + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + return _assert_raises( + except_cls, callable_, args, kwargs, msg=msg, check_context=True + ) + + +def assert_warns(except_cls, callable_, *args, **kwargs): + """legacy adapter function for functions that were previously using + assert_raises with SAWarning or similar. + + has some workarounds to accommodate the fact that the callable completes + with this approach rather than stopping at the exception raise. + + + """ + with _expect_warnings(except_cls, [".*"], squelch_other_warnings=True): + return callable_(*args, **kwargs) + + +def assert_warns_message(except_cls, msg, callable_, *args, **kwargs): + """legacy adapter function for functions that were previously using + assert_raises with SAWarning or similar. + + has some workarounds to accommodate the fact that the callable completes + with this approach rather than stopping at the exception raise. + + Also uses regex.search() to match the given message to the error string + rather than regex.match(). + + """ + with _expect_warnings( + except_cls, + [msg], + search_msg=True, + regex=False, + squelch_other_warnings=True, + ): + return callable_(*args, **kwargs) + + +def assert_raises_message_context_ok( + except_cls, msg, callable_, *args, **kwargs +): + return _assert_raises(except_cls, callable_, args, kwargs, msg=msg) + + +def _assert_raises( + except_cls, callable_, args, kwargs, msg=None, check_context=False +): + + with _expect_raises(except_cls, msg, check_context) as ec: + callable_(*args, **kwargs) + return ec.error + + +class _ErrorContainer(object): + error = None + + +@contextlib.contextmanager +def _expect_raises(except_cls, msg=None, check_context=False): + if ( + isinstance(except_cls, type) + and issubclass(except_cls, Warning) + or isinstance(except_cls, Warning) + ): + raise TypeError( + "Use expect_warnings for warnings, not " + "expect_raises / assert_raises" + ) + ec = _ErrorContainer() + if check_context: + are_we_already_in_a_traceback = sys.exc_info()[0] + try: + yield ec + success = False + except except_cls as err: + ec.error = err + success = True + if msg is not None: + assert re.search( + msg, util.text_type(err), re.UNICODE + ), "%r !~ %s" % (msg, err) + if check_context and not are_we_already_in_a_traceback: + _assert_proper_exception_context(err) + print(util.text_type(err).encode("utf-8")) + + # it's generally a good idea to not carry traceback objects outside + # of the except: block, but in this case especially we seem to have + # hit some bug in either python 3.10.0b2 or greenlet or both which + # this seems to fix: + # https://github.com/python-greenlet/greenlet/issues/242 + del ec + + # assert outside the block so it works for AssertionError too ! + assert success, "Callable did not raise an exception" + + +def expect_raises(except_cls, check_context=True): + return _expect_raises(except_cls, check_context=check_context) + + +def expect_raises_message(except_cls, msg, check_context=True): + return _expect_raises(except_cls, msg=msg, check_context=check_context) + + +class AssertsCompiledSQL(object): + def assert_compile( + self, + clause, + result, + params=None, + checkparams=None, + for_executemany=False, + check_literal_execute=None, + check_post_param=None, + dialect=None, + checkpositional=None, + check_prefetch=None, + use_default_dialect=False, + allow_dialect_select=False, + supports_default_values=True, + supports_default_metavalue=True, + literal_binds=False, + render_postcompile=False, + schema_translate_map=None, + render_schema_translate=False, + default_schema_name=None, + from_linting=False, + ): + if use_default_dialect: + dialect = default.DefaultDialect() + dialect.supports_default_values = supports_default_values + dialect.supports_default_metavalue = supports_default_metavalue + elif allow_dialect_select: + dialect = None + else: + if dialect is None: + dialect = getattr(self, "__dialect__", None) + + if dialect is None: + dialect = config.db.dialect + elif dialect == "default": + dialect = default.DefaultDialect() + dialect.supports_default_values = supports_default_values + dialect.supports_default_metavalue = supports_default_metavalue + elif dialect == "default_enhanced": + dialect = default.StrCompileDialect() + elif isinstance(dialect, util.string_types): + dialect = url.URL.create(dialect).get_dialect()() + + if default_schema_name: + dialect.default_schema_name = default_schema_name + + kw = {} + compile_kwargs = {} + + if schema_translate_map: + kw["schema_translate_map"] = schema_translate_map + + if params is not None: + kw["column_keys"] = list(params) + + if literal_binds: + compile_kwargs["literal_binds"] = True + + if render_postcompile: + compile_kwargs["render_postcompile"] = True + + if for_executemany: + kw["for_executemany"] = True + + if render_schema_translate: + kw["render_schema_translate"] = True + + if from_linting or getattr(self, "assert_from_linting", False): + kw["linting"] = sql.FROM_LINTING + + from sqlalchemy import orm + + if isinstance(clause, orm.Query): + stmt = clause._statement_20() + stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL + clause = stmt + + if compile_kwargs: + kw["compile_kwargs"] = compile_kwargs + + class DontAccess(object): + def __getattribute__(self, key): + raise NotImplementedError( + "compiler accessed .statement; use " + "compiler.current_executable" + ) + + class CheckCompilerAccess(object): + def __init__(self, test_statement): + self.test_statement = test_statement + self._annotations = {} + self.supports_execution = getattr( + test_statement, "supports_execution", False + ) + + if self.supports_execution: + self._execution_options = test_statement._execution_options + + if hasattr(test_statement, "_returning"): + self._returning = test_statement._returning + if hasattr(test_statement, "_inline"): + self._inline = test_statement._inline + if hasattr(test_statement, "_return_defaults"): + self._return_defaults = test_statement._return_defaults + + def _default_dialect(self): + return self.test_statement._default_dialect() + + def compile(self, dialect, **kw): + return self.test_statement.compile.__func__( + self, dialect=dialect, **kw + ) + + def _compiler(self, dialect, **kw): + return self.test_statement._compiler.__func__( + self, dialect, **kw + ) + + def _compiler_dispatch(self, compiler, **kwargs): + if hasattr(compiler, "statement"): + with mock.patch.object( + compiler, "statement", DontAccess() + ): + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + else: + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + + # no construct can assume it's the "top level" construct in all cases + # as anything can be nested. ensure constructs don't assume they + # are the "self.statement" element + c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw) + + if isinstance(clause, sqltypes.TypeEngine): + cache_key_no_warnings = clause._static_cache_key + if cache_key_no_warnings: + hash(cache_key_no_warnings) + else: + cache_key_no_warnings = clause._generate_cache_key() + if cache_key_no_warnings: + hash(cache_key_no_warnings[0]) + + param_str = repr(getattr(c, "params", {})) + if util.py3k: + param_str = param_str.encode("utf-8").decode("ascii", "ignore") + print( + ("\nSQL String:\n" + util.text_type(c) + param_str).encode( + "utf-8" + ) + ) + else: + print( + "\nSQL String:\n" + + util.text_type(c).encode("utf-8") + + param_str + ) + + cc = re.sub(r"[\n\t]", "", util.text_type(c)) + + eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) + + if checkparams is not None: + eq_(c.construct_params(params), checkparams) + if checkpositional is not None: + p = c.construct_params(params) + eq_(tuple([p[x] for x in c.positiontup]), checkpositional) + if check_prefetch is not None: + eq_(c.prefetch, check_prefetch) + if check_literal_execute is not None: + eq_( + { + c.bind_names[b]: b.effective_value + for b in c.literal_execute_params + }, + check_literal_execute, + ) + if check_post_param is not None: + eq_( + { + c.bind_names[b]: b.effective_value + for b in c.post_compile_params + }, + check_post_param, + ) + + +class ComparesTables(object): + def assert_tables_equal(self, table, reflected_table, strict_types=False): + assert len(table.c) == len(reflected_table.c) + for c, reflected_c in zip(table.c, reflected_table.c): + eq_(c.name, reflected_c.name) + assert reflected_c is reflected_table.c[c.name] + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) + + if strict_types: + msg = "Type '%s' doesn't correspond to type '%s'" + assert isinstance(reflected_c.type, type(c.type)), msg % ( + reflected_c.type, + c.type, + ) + else: + self.assert_types_base(reflected_c, c) + + if isinstance(c.type, sqltypes.String): + eq_(c.type.length, reflected_c.type.length) + + eq_( + {f.column.name for f in c.foreign_keys}, + {f.column.name for f in reflected_c.foreign_keys}, + ) + if c.server_default: + assert isinstance( + reflected_c.server_default, schema.FetchedValue + ) + + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] is not None + + def assert_types_base(self, c1, c2): + assert c1.type._compare_type_affinity( + c2.type + ), "On column %r, type '%s' doesn't correspond to type '%s'" % ( + c1.name, + c1.type, + c2.type, + ) + + +class AssertsExecutionResults(object): + def assert_result(self, result, class_, *objects): + result = list(result) + print(repr(result)) + self.assert_list(result, class_, objects) + + def assert_list(self, result, class_, list_): + self.assert_( + len(result) == len(list_), + "result list is not the same size as test list, " + + "for class " + + class_.__name__, + ) + for i in range(0, len(list_)): + self.assert_row(class_, result[i], list_[i]) + + def assert_row(self, class_, rowobj, desc): + self.assert_( + rowobj.__class__ is class_, "item class is not " + repr(class_) + ) + for key, value in desc.items(): + if isinstance(value, tuple): + if isinstance(value[1], list): + self.assert_list(getattr(rowobj, key), value[0], value[1]) + else: + self.assert_row(value[0], getattr(rowobj, key), value[1]) + else: + self.assert_( + getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" + % (key, getattr(rowobj, key), value), + ) + + def assert_unordered_result(self, result, cls, *expected): + """As assert_result, but the order of objects is not considered. + + The algorithm is very expensive but not a big deal for the small + numbers of rows that the test suite manipulates. + """ + + class immutabledict(dict): + def __hash__(self): + return id(self) + + found = util.IdentitySet(result) + expected = {immutabledict(e) for e in expected} + + for wrong in util.itertools_filterfalse( + lambda o: isinstance(o, cls), found + ): + fail( + 'Unexpected type "%s", expected "%s"' + % (type(wrong).__name__, cls.__name__) + ) + + if len(found) != len(expected): + fail( + 'Unexpected object count "%s", expected "%s"' + % (len(found), len(expected)) + ) + + NOVALUE = object() + + def _compare_item(obj, spec): + for key, value in spec.items(): + if isinstance(value, tuple): + try: + self.assert_unordered_result( + getattr(obj, key), value[0], *value[1] + ) + except AssertionError: + return False + else: + if getattr(obj, key, NOVALUE) != value: + return False + return True + + for expected_item in expected: + for found_item in found: + if _compare_item(found_item, expected_item): + found.remove(found_item) + break + else: + fail( + "Expected %s instance with attributes %s not found." + % (cls.__name__, repr(expected_item)) + ) + return True + + def sql_execution_asserter(self, db=None): + if db is None: + from . import db as db + + return assertsql.assert_engine(db) + + def assert_sql_execution(self, db, callable_, *rules): + with self.sql_execution_asserter(db) as asserter: + result = callable_() + asserter.assert_(*rules) + return result + + def assert_sql(self, db, callable_, rules): + + newrules = [] + for rule in rules: + if isinstance(rule, dict): + newrule = assertsql.AllOf( + *[assertsql.CompiledSQL(k, v) for k, v in rule.items()] + ) + else: + newrule = assertsql.CompiledSQL(*rule) + newrules.append(newrule) + + return self.assert_sql_execution(db, callable_, *newrules) + + def assert_sql_count(self, db, callable_, count): + self.assert_sql_execution( + db, callable_, assertsql.CountStatements(count) + ) + + def assert_multiple_sql_count(self, dbs, callable_, counts): + recs = [ + (self.sql_execution_asserter(db), db, count) + for (db, count) in zip(dbs, counts) + ] + asserters = [] + for ctx, db, count in recs: + asserters.append(ctx.__enter__()) + try: + return callable_() + finally: + for asserter, (ctx, db, count) in zip(asserters, recs): + ctx.__exit__(None, None, None) + asserter.assert_(assertsql.CountStatements(count)) + + @contextlib.contextmanager + def assert_execution(self, db, *rules): + with self.sql_execution_asserter(db) as asserter: + yield + asserter.assert_(*rules) + + def assert_statement_count(self, db, count): + return self.assert_execution(db, assertsql.CountStatements(count)) |