diff options
Diffstat (limited to 'lib/sqlalchemy/testing')
36 files changed, 16260 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py new file mode 100644 index 0000000..80d344f --- /dev/null +++ b/lib/sqlalchemy/testing/__init__.py @@ -0,0 +1,86 @@ +# testing/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +from . import config +from . import mock +from .assertions import assert_raises +from .assertions import assert_raises_context_ok +from .assertions import assert_raises_message +from .assertions import assert_raises_message_context_ok +from .assertions import assert_warns +from .assertions import assert_warns_message +from .assertions import AssertsCompiledSQL +from .assertions import AssertsExecutionResults +from .assertions import ComparesTables +from .assertions import emits_warning +from .assertions import emits_warning_on +from .assertions import eq_ +from .assertions import eq_ignore_whitespace +from .assertions import eq_regex +from .assertions import expect_deprecated +from .assertions import expect_deprecated_20 +from .assertions import expect_raises +from .assertions import expect_raises_message +from .assertions import expect_warnings +from .assertions import in_ +from .assertions import is_ +from .assertions import is_false +from .assertions import is_instance_of +from .assertions import is_none +from .assertions import is_not +from .assertions import is_not_ +from .assertions import is_not_none +from .assertions import is_true +from .assertions import le_ +from .assertions import ne_ +from .assertions import not_in +from .assertions import not_in_ +from .assertions import startswith_ +from .assertions import uses_deprecated +from .config import async_test +from .config import combinations +from .config import combinations_list +from .config import db +from .config import fixture +from .config import requirements as requires +from .exclusions import _is_excluded +from .exclusions import _server_version +from .exclusions import against as _against +from .exclusions import db_spec +from .exclusions import exclude +from .exclusions import fails +from .exclusions import fails_if +from .exclusions import fails_on +from .exclusions import fails_on_everything_except +from .exclusions import future +from .exclusions import only_if +from .exclusions import only_on +from .exclusions import skip +from .exclusions import skip_if +from .schema import eq_clause_element +from .schema import eq_type_affinity +from .util import adict +from .util import fail +from .util import flag_combinations +from .util import force_drop_names +from .util import lambda_combinations +from .util import metadata_fixture +from .util import provide_metadata +from .util import resolve_lambda +from .util import rowset +from .util import run_as_contextmanager +from .util import teardown_events +from .warnings import assert_warnings +from .warnings import warn_test_suite + + +def against(*queries): + return _against(config._current, *queries) + + +crashes = skip diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py new file mode 100644 index 0000000..9a3c06b --- /dev/null +++ b/lib/sqlalchemy/testing/assertions.py @@ -0,0 +1,845 @@ +# testing/assertions.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import absolute_import + +import contextlib +import re +import sys +import warnings + +from . import assertsql +from . import config +from . import engines +from . import mock +from .exclusions import db_spec +from .util import fail +from .. import exc as sa_exc +from .. import schema +from .. import sql +from .. import types as sqltypes +from .. import util +from ..engine import default +from ..engine import url +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util import compat +from ..util import decorator + + +def expect_warnings(*messages, **kw): + """Context manager which expects one or more warnings. + + With no arguments, squelches all SAWarning and RemovedIn20Warning emitted via + sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise + pass string expressions that will match selected warnings via regex; + all non-matching warnings are sent through. + + The expect version **asserts** that the warnings were in fact seen. + + Note that the test suite sets SAWarning warnings to raise exceptions. + + """ # noqa + return _expect_warnings( + (sa_exc.RemovedIn20Warning, sa_exc.SAWarning), messages, **kw + ) + + +@contextlib.contextmanager +def expect_warnings_on(db, *messages, **kw): + """Context manager which expects one or more warnings on specific + dialects. + + The expect version **asserts** that the warnings were in fact seen. + + """ + spec = db_spec(db) + + if isinstance(db, util.string_types) and not spec(config._current): + yield + else: + with expect_warnings(*messages, **kw): + yield + + +def emits_warning(*messages): + """Decorator form of expect_warnings(). + + Note that emits_warning does **not** assert that the warnings + were in fact seen. + + """ + + @decorator + def decorate(fn, *args, **kw): + with expect_warnings(assert_=False, *messages): + return fn(*args, **kw) + + return decorate + + +def expect_deprecated(*messages, **kw): + return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw) + + +def expect_deprecated_20(*messages, **kw): + return _expect_warnings(sa_exc.Base20DeprecationWarning, messages, **kw) + + +def emits_warning_on(db, *messages): + """Mark a test as emitting a warning on a specific dialect. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + + Note that emits_warning_on does **not** assert that the warnings + were in fact seen. + + """ + + @decorator + def decorate(fn, *args, **kw): + with expect_warnings_on(db, assert_=False, *messages): + return fn(*args, **kw) + + return decorate + + +def uses_deprecated(*messages): + """Mark a test as immune from fatal deprecation warnings. + + With no arguments, squelches all SADeprecationWarning failures. + Or pass one or more strings; these will be matched to the root + of the warning description by warnings.filterwarnings(). + + As a special case, you may pass a function name prefixed with // + and it will be re-written as needed to match the standard warning + verbiage emitted by the sqlalchemy.util.deprecated decorator. + + Note that uses_deprecated does **not** assert that the warnings + were in fact seen. + + """ + + @decorator + def decorate(fn, *args, **kw): + with expect_deprecated(*messages, assert_=False): + return fn(*args, **kw) + + return decorate + + +_FILTERS = None +_SEEN = None +_EXC_CLS = None + + +@contextlib.contextmanager +def _expect_warnings( + exc_cls, + messages, + regex=True, + search_msg=False, + assert_=True, + py2konly=False, + raise_on_any_unexpected=False, + squelch_other_warnings=False, +): + + global _FILTERS, _SEEN, _EXC_CLS + + if regex or search_msg: + filters = [re.compile(msg, re.I | re.S) for msg in messages] + else: + filters = list(messages) + + if _FILTERS is not None: + # nested call; update _FILTERS and _SEEN, return. outer + # block will assert our messages + assert _SEEN is not None + assert _EXC_CLS is not None + _FILTERS.extend(filters) + _SEEN.update(filters) + _EXC_CLS += (exc_cls,) + yield + else: + seen = _SEEN = set(filters) + _FILTERS = filters + _EXC_CLS = (exc_cls,) + + if raise_on_any_unexpected: + + def real_warn(msg, *arg, **kw): + raise AssertionError("Got unexpected warning: %r" % msg) + + else: + real_warn = warnings.warn + + def our_warn(msg, *arg, **kw): + + if isinstance(msg, _EXC_CLS): + exception = type(msg) + msg = str(msg) + elif arg: + exception = arg[0] + else: + exception = None + + if not exception or not issubclass(exception, _EXC_CLS): + if not squelch_other_warnings: + return real_warn(msg, *arg, **kw) + else: + return + + if not filters and not raise_on_any_unexpected: + return + + for filter_ in filters: + if ( + (search_msg and filter_.search(msg)) + or (regex and filter_.match(msg)) + or (not regex and filter_ == msg) + ): + seen.discard(filter_) + break + else: + if not squelch_other_warnings: + real_warn(msg, *arg, **kw) + + with mock.patch("warnings.warn", our_warn), mock.patch( + "sqlalchemy.util.SQLALCHEMY_WARN_20", True + ), mock.patch( + "sqlalchemy.util.deprecations.SQLALCHEMY_WARN_20", True + ), mock.patch( + "sqlalchemy.engine.row.LegacyRow._default_key_style", 2 + ): + try: + yield + finally: + _SEEN = _FILTERS = _EXC_CLS = None + + if assert_ and (not py2konly or not compat.py3k): + assert not seen, "Warnings were not seen: %s" % ", ".join( + "%r" % (s.pattern if regex else s) for s in seen + ) + + +def global_cleanup_assertions(): + """Check things that have to be finalized at the end of a test suite. + + Hardcoded at the moment, a modular system can be built here + to support things like PG prepared transactions, tables all + dropped, etc. + + """ + _assert_no_stray_pool_connections() + + +def _assert_no_stray_pool_connections(): + engines.testing_reaper.assert_all_closed() + + +def eq_regex(a, b, msg=None): + assert re.match(b, a), msg or "%r !~ %r" % (a, b) + + +def eq_(a, b, msg=None): + """Assert a == b, with repr messaging on failure.""" + assert a == b, msg or "%r != %r" % (a, b) + + +def ne_(a, b, msg=None): + """Assert a != b, with repr messaging on failure.""" + assert a != b, msg or "%r == %r" % (a, b) + + +def le_(a, b, msg=None): + """Assert a <= b, with repr messaging on failure.""" + assert a <= b, msg or "%r != %r" % (a, b) + + +def is_instance_of(a, b, msg=None): + assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b) + + +def is_none(a, msg=None): + is_(a, None, msg=msg) + + +def is_not_none(a, msg=None): + is_not(a, None, msg=msg) + + +def is_true(a, msg=None): + is_(bool(a), True, msg=msg) + + +def is_false(a, msg=None): + is_(bool(a), False, msg=msg) + + +def is_(a, b, msg=None): + """Assert a is b, with repr messaging on failure.""" + assert a is b, msg or "%r is not %r" % (a, b) + + +def is_not(a, b, msg=None): + """Assert a is not b, with repr messaging on failure.""" + assert a is not b, msg or "%r is %r" % (a, b) + + +# deprecated. See #5429 +is_not_ = is_not + + +def in_(a, b, msg=None): + """Assert a in b, with repr messaging on failure.""" + assert a in b, msg or "%r not in %r" % (a, b) + + +def not_in(a, b, msg=None): + """Assert a in not b, with repr messaging on failure.""" + assert a not in b, msg or "%r is in %r" % (a, b) + + +# deprecated. See #5429 +not_in_ = not_in + + +def startswith_(a, fragment, msg=None): + """Assert a.startswith(fragment), with repr messaging on failure.""" + assert a.startswith(fragment), msg or "%r does not start with %r" % ( + a, + fragment, + ) + + +def eq_ignore_whitespace(a, b, msg=None): + a = re.sub(r"^\s+?|\n", "", a) + a = re.sub(r" {2,}", " ", a) + b = re.sub(r"^\s+?|\n", "", b) + b = re.sub(r" {2,}", " ", b) + + assert a == b, msg or "%r != %r" % (a, b) + + +def _assert_proper_exception_context(exception): + """assert that any exception we're catching does not have a __context__ + without a __cause__, and that __suppress_context__ is never set. + + Python 3 will report nested as exceptions as "during the handling of + error X, error Y occurred". That's not what we want to do. we want + these exceptions in a cause chain. + + """ + + if not util.py3k: + return + + if ( + exception.__context__ is not exception.__cause__ + and not exception.__suppress_context__ + ): + assert False, ( + "Exception %r was correctly raised but did not set a cause, " + "within context %r as its cause." + % (exception, exception.__context__) + ) + + +def assert_raises(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw, check_context=True) + + +def assert_raises_context_ok(except_cls, callable_, *args, **kw): + return _assert_raises(except_cls, callable_, args, kw) + + +def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): + return _assert_raises( + except_cls, callable_, args, kwargs, msg=msg, check_context=True + ) + + +def assert_warns(except_cls, callable_, *args, **kwargs): + """legacy adapter function for functions that were previously using + assert_raises with SAWarning or similar. + + has some workarounds to accommodate the fact that the callable completes + with this approach rather than stopping at the exception raise. + + + """ + with _expect_warnings(except_cls, [".*"], squelch_other_warnings=True): + return callable_(*args, **kwargs) + + +def assert_warns_message(except_cls, msg, callable_, *args, **kwargs): + """legacy adapter function for functions that were previously using + assert_raises with SAWarning or similar. + + has some workarounds to accommodate the fact that the callable completes + with this approach rather than stopping at the exception raise. + + Also uses regex.search() to match the given message to the error string + rather than regex.match(). + + """ + with _expect_warnings( + except_cls, + [msg], + search_msg=True, + regex=False, + squelch_other_warnings=True, + ): + return callable_(*args, **kwargs) + + +def assert_raises_message_context_ok( + except_cls, msg, callable_, *args, **kwargs +): + return _assert_raises(except_cls, callable_, args, kwargs, msg=msg) + + +def _assert_raises( + except_cls, callable_, args, kwargs, msg=None, check_context=False +): + + with _expect_raises(except_cls, msg, check_context) as ec: + callable_(*args, **kwargs) + return ec.error + + +class _ErrorContainer(object): + error = None + + +@contextlib.contextmanager +def _expect_raises(except_cls, msg=None, check_context=False): + if ( + isinstance(except_cls, type) + and issubclass(except_cls, Warning) + or isinstance(except_cls, Warning) + ): + raise TypeError( + "Use expect_warnings for warnings, not " + "expect_raises / assert_raises" + ) + ec = _ErrorContainer() + if check_context: + are_we_already_in_a_traceback = sys.exc_info()[0] + try: + yield ec + success = False + except except_cls as err: + ec.error = err + success = True + if msg is not None: + assert re.search( + msg, util.text_type(err), re.UNICODE + ), "%r !~ %s" % (msg, err) + if check_context and not are_we_already_in_a_traceback: + _assert_proper_exception_context(err) + print(util.text_type(err).encode("utf-8")) + + # it's generally a good idea to not carry traceback objects outside + # of the except: block, but in this case especially we seem to have + # hit some bug in either python 3.10.0b2 or greenlet or both which + # this seems to fix: + # https://github.com/python-greenlet/greenlet/issues/242 + del ec + + # assert outside the block so it works for AssertionError too ! + assert success, "Callable did not raise an exception" + + +def expect_raises(except_cls, check_context=True): + return _expect_raises(except_cls, check_context=check_context) + + +def expect_raises_message(except_cls, msg, check_context=True): + return _expect_raises(except_cls, msg=msg, check_context=check_context) + + +class AssertsCompiledSQL(object): + def assert_compile( + self, + clause, + result, + params=None, + checkparams=None, + for_executemany=False, + check_literal_execute=None, + check_post_param=None, + dialect=None, + checkpositional=None, + check_prefetch=None, + use_default_dialect=False, + allow_dialect_select=False, + supports_default_values=True, + supports_default_metavalue=True, + literal_binds=False, + render_postcompile=False, + schema_translate_map=None, + render_schema_translate=False, + default_schema_name=None, + from_linting=False, + ): + if use_default_dialect: + dialect = default.DefaultDialect() + dialect.supports_default_values = supports_default_values + dialect.supports_default_metavalue = supports_default_metavalue + elif allow_dialect_select: + dialect = None + else: + if dialect is None: + dialect = getattr(self, "__dialect__", None) + + if dialect is None: + dialect = config.db.dialect + elif dialect == "default": + dialect = default.DefaultDialect() + dialect.supports_default_values = supports_default_values + dialect.supports_default_metavalue = supports_default_metavalue + elif dialect == "default_enhanced": + dialect = default.StrCompileDialect() + elif isinstance(dialect, util.string_types): + dialect = url.URL.create(dialect).get_dialect()() + + if default_schema_name: + dialect.default_schema_name = default_schema_name + + kw = {} + compile_kwargs = {} + + if schema_translate_map: + kw["schema_translate_map"] = schema_translate_map + + if params is not None: + kw["column_keys"] = list(params) + + if literal_binds: + compile_kwargs["literal_binds"] = True + + if render_postcompile: + compile_kwargs["render_postcompile"] = True + + if for_executemany: + kw["for_executemany"] = True + + if render_schema_translate: + kw["render_schema_translate"] = True + + if from_linting or getattr(self, "assert_from_linting", False): + kw["linting"] = sql.FROM_LINTING + + from sqlalchemy import orm + + if isinstance(clause, orm.Query): + stmt = clause._statement_20() + stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL + clause = stmt + + if compile_kwargs: + kw["compile_kwargs"] = compile_kwargs + + class DontAccess(object): + def __getattribute__(self, key): + raise NotImplementedError( + "compiler accessed .statement; use " + "compiler.current_executable" + ) + + class CheckCompilerAccess(object): + def __init__(self, test_statement): + self.test_statement = test_statement + self._annotations = {} + self.supports_execution = getattr( + test_statement, "supports_execution", False + ) + + if self.supports_execution: + self._execution_options = test_statement._execution_options + + if hasattr(test_statement, "_returning"): + self._returning = test_statement._returning + if hasattr(test_statement, "_inline"): + self._inline = test_statement._inline + if hasattr(test_statement, "_return_defaults"): + self._return_defaults = test_statement._return_defaults + + def _default_dialect(self): + return self.test_statement._default_dialect() + + def compile(self, dialect, **kw): + return self.test_statement.compile.__func__( + self, dialect=dialect, **kw + ) + + def _compiler(self, dialect, **kw): + return self.test_statement._compiler.__func__( + self, dialect, **kw + ) + + def _compiler_dispatch(self, compiler, **kwargs): + if hasattr(compiler, "statement"): + with mock.patch.object( + compiler, "statement", DontAccess() + ): + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + else: + return self.test_statement._compiler_dispatch( + compiler, **kwargs + ) + + # no construct can assume it's the "top level" construct in all cases + # as anything can be nested. ensure constructs don't assume they + # are the "self.statement" element + c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw) + + if isinstance(clause, sqltypes.TypeEngine): + cache_key_no_warnings = clause._static_cache_key + if cache_key_no_warnings: + hash(cache_key_no_warnings) + else: + cache_key_no_warnings = clause._generate_cache_key() + if cache_key_no_warnings: + hash(cache_key_no_warnings[0]) + + param_str = repr(getattr(c, "params", {})) + if util.py3k: + param_str = param_str.encode("utf-8").decode("ascii", "ignore") + print( + ("\nSQL String:\n" + util.text_type(c) + param_str).encode( + "utf-8" + ) + ) + else: + print( + "\nSQL String:\n" + + util.text_type(c).encode("utf-8") + + param_str + ) + + cc = re.sub(r"[\n\t]", "", util.text_type(c)) + + eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect)) + + if checkparams is not None: + eq_(c.construct_params(params), checkparams) + if checkpositional is not None: + p = c.construct_params(params) + eq_(tuple([p[x] for x in c.positiontup]), checkpositional) + if check_prefetch is not None: + eq_(c.prefetch, check_prefetch) + if check_literal_execute is not None: + eq_( + { + c.bind_names[b]: b.effective_value + for b in c.literal_execute_params + }, + check_literal_execute, + ) + if check_post_param is not None: + eq_( + { + c.bind_names[b]: b.effective_value + for b in c.post_compile_params + }, + check_post_param, + ) + + +class ComparesTables(object): + def assert_tables_equal(self, table, reflected_table, strict_types=False): + assert len(table.c) == len(reflected_table.c) + for c, reflected_c in zip(table.c, reflected_table.c): + eq_(c.name, reflected_c.name) + assert reflected_c is reflected_table.c[c.name] + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) + + if strict_types: + msg = "Type '%s' doesn't correspond to type '%s'" + assert isinstance(reflected_c.type, type(c.type)), msg % ( + reflected_c.type, + c.type, + ) + else: + self.assert_types_base(reflected_c, c) + + if isinstance(c.type, sqltypes.String): + eq_(c.type.length, reflected_c.type.length) + + eq_( + {f.column.name for f in c.foreign_keys}, + {f.column.name for f in reflected_c.foreign_keys}, + ) + if c.server_default: + assert isinstance( + reflected_c.server_default, schema.FetchedValue + ) + + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] is not None + + def assert_types_base(self, c1, c2): + assert c1.type._compare_type_affinity( + c2.type + ), "On column %r, type '%s' doesn't correspond to type '%s'" % ( + c1.name, + c1.type, + c2.type, + ) + + +class AssertsExecutionResults(object): + def assert_result(self, result, class_, *objects): + result = list(result) + print(repr(result)) + self.assert_list(result, class_, objects) + + def assert_list(self, result, class_, list_): + self.assert_( + len(result) == len(list_), + "result list is not the same size as test list, " + + "for class " + + class_.__name__, + ) + for i in range(0, len(list_)): + self.assert_row(class_, result[i], list_[i]) + + def assert_row(self, class_, rowobj, desc): + self.assert_( + rowobj.__class__ is class_, "item class is not " + repr(class_) + ) + for key, value in desc.items(): + if isinstance(value, tuple): + if isinstance(value[1], list): + self.assert_list(getattr(rowobj, key), value[0], value[1]) + else: + self.assert_row(value[0], getattr(rowobj, key), value[1]) + else: + self.assert_( + getattr(rowobj, key) == value, + "attribute %s value %s does not match %s" + % (key, getattr(rowobj, key), value), + ) + + def assert_unordered_result(self, result, cls, *expected): + """As assert_result, but the order of objects is not considered. + + The algorithm is very expensive but not a big deal for the small + numbers of rows that the test suite manipulates. + """ + + class immutabledict(dict): + def __hash__(self): + return id(self) + + found = util.IdentitySet(result) + expected = {immutabledict(e) for e in expected} + + for wrong in util.itertools_filterfalse( + lambda o: isinstance(o, cls), found + ): + fail( + 'Unexpected type "%s", expected "%s"' + % (type(wrong).__name__, cls.__name__) + ) + + if len(found) != len(expected): + fail( + 'Unexpected object count "%s", expected "%s"' + % (len(found), len(expected)) + ) + + NOVALUE = object() + + def _compare_item(obj, spec): + for key, value in spec.items(): + if isinstance(value, tuple): + try: + self.assert_unordered_result( + getattr(obj, key), value[0], *value[1] + ) + except AssertionError: + return False + else: + if getattr(obj, key, NOVALUE) != value: + return False + return True + + for expected_item in expected: + for found_item in found: + if _compare_item(found_item, expected_item): + found.remove(found_item) + break + else: + fail( + "Expected %s instance with attributes %s not found." + % (cls.__name__, repr(expected_item)) + ) + return True + + def sql_execution_asserter(self, db=None): + if db is None: + from . import db as db + + return assertsql.assert_engine(db) + + def assert_sql_execution(self, db, callable_, *rules): + with self.sql_execution_asserter(db) as asserter: + result = callable_() + asserter.assert_(*rules) + return result + + def assert_sql(self, db, callable_, rules): + + newrules = [] + for rule in rules: + if isinstance(rule, dict): + newrule = assertsql.AllOf( + *[assertsql.CompiledSQL(k, v) for k, v in rule.items()] + ) + else: + newrule = assertsql.CompiledSQL(*rule) + newrules.append(newrule) + + return self.assert_sql_execution(db, callable_, *newrules) + + def assert_sql_count(self, db, callable_, count): + self.assert_sql_execution( + db, callable_, assertsql.CountStatements(count) + ) + + def assert_multiple_sql_count(self, dbs, callable_, counts): + recs = [ + (self.sql_execution_asserter(db), db, count) + for (db, count) in zip(dbs, counts) + ] + asserters = [] + for ctx, db, count in recs: + asserters.append(ctx.__enter__()) + try: + return callable_() + finally: + for asserter, (ctx, db, count) in zip(asserters, recs): + ctx.__exit__(None, None, None) + asserter.assert_(assertsql.CountStatements(count)) + + @contextlib.contextmanager + def assert_execution(self, db, *rules): + with self.sql_execution_asserter(db) as asserter: + yield + asserter.assert_(*rules) + + def assert_statement_count(self, db, count): + return self.assert_execution(db, assertsql.CountStatements(count)) diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py new file mode 100644 index 0000000..565b3ed --- /dev/null +++ b/lib/sqlalchemy/testing/assertsql.py @@ -0,0 +1,457 @@ +# testing/assertsql.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import collections +import contextlib +import re + +from .. import event +from .. import util +from ..engine import url +from ..engine.default import DefaultDialect +from ..engine.util import _distill_cursor_params +from ..schema import _DDLCompiles + + +class AssertRule(object): + + is_consumed = False + errormessage = None + consume_statement = True + + def process_statement(self, execute_observed): + pass + + def no_more_statements(self): + assert False, ( + "All statements are complete, but pending " + "assertion rules remain" + ) + + +class SQLMatchRule(AssertRule): + pass + + +class CursorSQL(SQLMatchRule): + def __init__(self, statement, params=None, consume_statement=True): + self.statement = statement + self.params = params + self.consume_statement = consume_statement + + def process_statement(self, execute_observed): + stmt = execute_observed.statements[0] + if self.statement != stmt.statement or ( + self.params is not None and self.params != stmt.parameters + ): + self.errormessage = ( + "Testing for exact SQL %s parameters %s received %s %s" + % ( + self.statement, + self.params, + stmt.statement, + stmt.parameters, + ) + ) + else: + execute_observed.statements.pop(0) + self.is_consumed = True + if not execute_observed.statements: + self.consume_statement = True + + +class CompiledSQL(SQLMatchRule): + def __init__(self, statement, params=None, dialect="default"): + self.statement = statement + self.params = params + self.dialect = dialect + + def _compare_sql(self, execute_observed, received_statement): + stmt = re.sub(r"[\n\t]", "", self.statement) + return received_statement == stmt + + def _compile_dialect(self, execute_observed): + if self.dialect == "default": + dialect = DefaultDialect() + # this is currently what tests are expecting + # dialect.supports_default_values = True + dialect.supports_default_metavalue = True + return dialect + else: + # ugh + if self.dialect == "postgresql": + params = {"implicit_returning": True} + else: + params = {} + return url.URL.create(self.dialect).get_dialect()(**params) + + def _received_statement(self, execute_observed): + """reconstruct the statement and params in terms + of a target dialect, which for CompiledSQL is just DefaultDialect.""" + + context = execute_observed.context + compare_dialect = self._compile_dialect(execute_observed) + + # received_statement runs a full compile(). we should not need to + # consider extracted_parameters; if we do this indicates some state + # is being sent from a previous cached query, which some misbehaviors + # in the ORM can cause, see #6881 + cache_key = None # execute_observed.context.compiled.cache_key + extracted_parameters = ( + None # execute_observed.context.extracted_parameters + ) + + if "schema_translate_map" in context.execution_options: + map_ = context.execution_options["schema_translate_map"] + else: + map_ = None + + if isinstance(execute_observed.clauseelement, _DDLCompiles): + + compiled = execute_observed.clauseelement.compile( + dialect=compare_dialect, + schema_translate_map=map_, + ) + else: + compiled = execute_observed.clauseelement.compile( + cache_key=cache_key, + dialect=compare_dialect, + column_keys=context.compiled.column_keys, + for_executemany=context.compiled.for_executemany, + schema_translate_map=map_, + ) + _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled)) + parameters = execute_observed.parameters + + if not parameters: + _received_parameters = [ + compiled.construct_params( + extracted_parameters=extracted_parameters + ) + ] + else: + _received_parameters = [ + compiled.construct_params( + m, extracted_parameters=extracted_parameters + ) + for m in parameters + ] + + return _received_statement, _received_parameters + + def process_statement(self, execute_observed): + context = execute_observed.context + + _received_statement, _received_parameters = self._received_statement( + execute_observed + ) + params = self._all_params(context) + + equivalent = self._compare_sql(execute_observed, _received_statement) + + if equivalent: + if params is not None: + all_params = list(params) + all_received = list(_received_parameters) + while all_params and all_received: + param = dict(all_params.pop(0)) + + for idx, received in enumerate(list(all_received)): + # do a positive compare only + for param_key in param: + # a key in param did not match current + # 'received' + if ( + param_key not in received + or received[param_key] != param[param_key] + ): + break + else: + # all keys in param matched 'received'; + # onto next param + del all_received[idx] + break + else: + # param did not match any entry + # in all_received + equivalent = False + break + if all_params or all_received: + equivalent = False + + if equivalent: + self.is_consumed = True + self.errormessage = None + else: + self.errormessage = self._failure_message(params) % { + "received_statement": _received_statement, + "received_parameters": _received_parameters, + } + + def _all_params(self, context): + if self.params: + if callable(self.params): + params = self.params(context) + else: + params = self.params + if not isinstance(params, list): + params = [params] + return params + else: + return None + + def _failure_message(self, expected_params): + return ( + "Testing for compiled statement\n%r partial params %s, " + "received\n%%(received_statement)r with params " + "%%(received_parameters)r" + % ( + self.statement.replace("%", "%%"), + repr(expected_params).replace("%", "%%"), + ) + ) + + +class RegexSQL(CompiledSQL): + def __init__(self, regex, params=None, dialect="default"): + SQLMatchRule.__init__(self) + self.regex = re.compile(regex) + self.orig_regex = regex + self.params = params + self.dialect = dialect + + def _failure_message(self, expected_params): + return ( + "Testing for compiled statement ~%r partial params %s, " + "received %%(received_statement)r with params " + "%%(received_parameters)r" + % ( + self.orig_regex.replace("%", "%%"), + repr(expected_params).replace("%", "%%"), + ) + ) + + def _compare_sql(self, execute_observed, received_statement): + return bool(self.regex.match(received_statement)) + + +class DialectSQL(CompiledSQL): + def _compile_dialect(self, execute_observed): + return execute_observed.context.dialect + + def _compare_no_space(self, real_stmt, received_stmt): + stmt = re.sub(r"[\n\t]", "", real_stmt) + return received_stmt == stmt + + def _received_statement(self, execute_observed): + received_stmt, received_params = super( + DialectSQL, self + )._received_statement(execute_observed) + + # TODO: why do we need this part? + for real_stmt in execute_observed.statements: + if self._compare_no_space(real_stmt.statement, received_stmt): + break + else: + raise AssertionError( + "Can't locate compiled statement %r in list of " + "statements actually invoked" % received_stmt + ) + + return received_stmt, execute_observed.context.compiled_parameters + + def _compare_sql(self, execute_observed, received_statement): + stmt = re.sub(r"[\n\t]", "", self.statement) + # convert our comparison statement to have the + # paramstyle of the received + paramstyle = execute_observed.context.dialect.paramstyle + if paramstyle == "pyformat": + stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt) + else: + # positional params + repl = None + if paramstyle == "qmark": + repl = "?" + elif paramstyle == "format": + repl = r"%s" + elif paramstyle == "numeric": + repl = None + stmt = re.sub(r":([\w_]+)", repl, stmt) + + return received_statement == stmt + + +class CountStatements(AssertRule): + def __init__(self, count): + self.count = count + self._statement_count = 0 + + def process_statement(self, execute_observed): + self._statement_count += 1 + + def no_more_statements(self): + if self.count != self._statement_count: + assert False, "desired statement count %d does not match %d" % ( + self.count, + self._statement_count, + ) + + +class AllOf(AssertRule): + def __init__(self, *rules): + self.rules = set(rules) + + def process_statement(self, execute_observed): + for rule in list(self.rules): + rule.errormessage = None + rule.process_statement(execute_observed) + if rule.is_consumed: + self.rules.discard(rule) + if not self.rules: + self.is_consumed = True + break + elif not rule.errormessage: + # rule is not done yet + self.errormessage = None + break + else: + self.errormessage = list(self.rules)[0].errormessage + + +class EachOf(AssertRule): + def __init__(self, *rules): + self.rules = list(rules) + + def process_statement(self, execute_observed): + while self.rules: + rule = self.rules[0] + rule.process_statement(execute_observed) + if rule.is_consumed: + self.rules.pop(0) + elif rule.errormessage: + self.errormessage = rule.errormessage + if rule.consume_statement: + break + + if not self.rules: + self.is_consumed = True + + def no_more_statements(self): + if self.rules and not self.rules[0].is_consumed: + self.rules[0].no_more_statements() + elif self.rules: + super(EachOf, self).no_more_statements() + + +class Conditional(EachOf): + def __init__(self, condition, rules, else_rules): + if condition: + super(Conditional, self).__init__(*rules) + else: + super(Conditional, self).__init__(*else_rules) + + +class Or(AllOf): + def process_statement(self, execute_observed): + for rule in self.rules: + rule.process_statement(execute_observed) + if rule.is_consumed: + self.is_consumed = True + break + else: + self.errormessage = list(self.rules)[0].errormessage + + +class SQLExecuteObserved(object): + def __init__(self, context, clauseelement, multiparams, params): + self.context = context + self.clauseelement = clauseelement + self.parameters = _distill_cursor_params( + context.connection, tuple(multiparams), params + ) + self.statements = [] + + def __repr__(self): + return str(self.statements) + + +class SQLCursorExecuteObserved( + collections.namedtuple( + "SQLCursorExecuteObserved", + ["statement", "parameters", "context", "executemany"], + ) +): + pass + + +class SQLAsserter(object): + def __init__(self): + self.accumulated = [] + + def _close(self): + self._final = self.accumulated + del self.accumulated + + def assert_(self, *rules): + rule = EachOf(*rules) + + observed = list(self._final) + while observed: + statement = observed.pop(0) + rule.process_statement(statement) + if rule.is_consumed: + break + elif rule.errormessage: + assert False, rule.errormessage + if observed: + assert False, "Additional SQL statements remain:\n%s" % observed + elif not rule.is_consumed: + rule.no_more_statements() + + +@contextlib.contextmanager +def assert_engine(engine): + asserter = SQLAsserter() + + orig = [] + + @event.listens_for(engine, "before_execute") + def connection_execute( + conn, clauseelement, multiparams, params, execution_options + ): + # grab the original statement + params before any cursor + # execution + orig[:] = clauseelement, multiparams, params + + @event.listens_for(engine, "after_cursor_execute") + def cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): + if not context: + return + # then grab real cursor statements and associate them all + # around a single context + if ( + asserter.accumulated + and asserter.accumulated[-1].context is context + ): + obs = asserter.accumulated[-1] + else: + obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2]) + asserter.accumulated.append(obs) + obs.statements.append( + SQLCursorExecuteObserved( + statement, parameters, context, executemany + ) + ) + + try: + yield asserter + finally: + event.remove(engine, "after_cursor_execute", cursor_execute) + event.remove(engine, "before_execute", connection_execute) + asserter._close() diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py new file mode 100644 index 0000000..2189060 --- /dev/null +++ b/lib/sqlalchemy/testing/asyncio.py @@ -0,0 +1,128 @@ +# testing/asyncio.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +# functions and wrappers to run tests, fixtures, provisioning and +# setup/teardown in an asyncio event loop, conditionally based on the +# current DB driver being used for a test. + +# note that SQLAlchemy's asyncio integration also supports a method +# of running individual asyncio functions inside of separate event loops +# using "async_fallback" mode; however running whole functions in the event +# loop is a more accurate test for how SQLAlchemy's asyncio features +# would run in the real world. + + +from functools import wraps +import inspect + +from . import config +from ..util.concurrency import _util_async_run +from ..util.concurrency import _util_async_run_coroutine_function + +# may be set to False if the +# --disable-asyncio flag is passed to the test runner. +ENABLE_ASYNCIO = True + + +def _run_coroutine_function(fn, *args, **kwargs): + return _util_async_run_coroutine_function(fn, *args, **kwargs) + + +def _assume_async(fn, *args, **kwargs): + """Run a function in an asyncio loop unconditionally. + + This function is used for provisioning features like + testing a database connection for server info. + + Note that for blocking IO database drivers, this means they block the + event loop. + + """ + + if not ENABLE_ASYNCIO: + return fn(*args, **kwargs) + + return _util_async_run(fn, *args, **kwargs) + + +def _maybe_async_provisioning(fn, *args, **kwargs): + """Run a function in an asyncio loop if any current drivers might need it. + + This function is used for provisioning features that take + place outside of a specific database driver being selected, so if the + current driver that happens to be used for the provisioning operation + is an async driver, it will run in asyncio and not fail. + + Note that for blocking IO database drivers, this means they block the + event loop. + + """ + if not ENABLE_ASYNCIO: + return fn(*args, **kwargs) + + if config.any_async: + return _util_async_run(fn, *args, **kwargs) + else: + return fn(*args, **kwargs) + + +def _maybe_async(fn, *args, **kwargs): + """Run a function in an asyncio loop if the current selected driver is + async. + + This function is used for test setup/teardown and tests themselves + where the current DB driver is known. + + + """ + if not ENABLE_ASYNCIO: + + return fn(*args, **kwargs) + + is_async = config._current.is_async + + if is_async: + return _util_async_run(fn, *args, **kwargs) + else: + return fn(*args, **kwargs) + + +def _maybe_async_wrapper(fn): + """Apply the _maybe_async function to an existing function and return + as a wrapped callable, supporting generator functions as well. + + This is currently used for pytest fixtures that support generator use. + + """ + + if inspect.isgeneratorfunction(fn): + _stop = object() + + def call_next(gen): + try: + return next(gen) + # can't raise StopIteration in an awaitable. + except StopIteration: + return _stop + + @wraps(fn) + def wrap_fixture(*args, **kwargs): + gen = fn(*args, **kwargs) + while True: + value = _maybe_async(call_next, gen) + if value is _stop: + break + yield value + + else: + + @wraps(fn) + def wrap_fixture(*args, **kwargs): + return _maybe_async(fn, *args, **kwargs) + + return wrap_fixture diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py new file mode 100644 index 0000000..fc13a16 --- /dev/null +++ b/lib/sqlalchemy/testing/config.py @@ -0,0 +1,209 @@ +# testing/config.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import collections + +from .. import util + +requirements = None +db = None +db_url = None +db_opts = None +file_config = None +test_schema = None +test_schema_2 = None +any_async = False +_current = None +ident = "main" + +_fixture_functions = None # installed by plugin_base + + +def combinations(*comb, **kw): + r"""Deliver multiple versions of a test based on positional combinations. + + This is a facade over pytest.mark.parametrize. + + + :param \*comb: argument combinations. These are tuples that will be passed + positionally to the decorated function. + + :param argnames: optional list of argument names. These are the names + of the arguments in the test function that correspond to the entries + in each argument tuple. pytest.mark.parametrize requires this, however + the combinations function will derive it automatically if not present + by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the + first argument is "self" which is discarded. + + :param id\_: optional id template. This is a string template that + describes how the "id" for each parameter set should be defined, if any. + The number of characters in the template should match the number of + entries in each argument tuple. Each character describes how the + corresponding entry in the argument tuple should be handled, as far as + whether or not it is included in the arguments passed to the function, as + well as if it is included in the tokens used to create the id of the + parameter set. + + If omitted, the argument combinations are passed to parametrize as is. If + passed, each argument combination is turned into a pytest.param() object, + mapping the elements of the argument tuple to produce an id based on a + character value in the same position within the string template using the + following scheme:: + + i - the given argument is a string that is part of the id only, don't + pass it as an argument + + n - the given argument should be passed and it should be added to the + id by calling the .__name__ attribute + + r - the given argument should be passed and it should be added to the + id by calling repr() + + s - the given argument should be passed and it should be added to the + id by calling str() + + a - (argument) the given argument should be passed and it should not + be used to generated the id + + e.g.:: + + @testing.combinations( + (operator.eq, "eq"), + (operator.ne, "ne"), + (operator.gt, "gt"), + (operator.lt, "lt"), + id_="na" + ) + def test_operator(self, opfunc, name): + pass + + The above combination will call ``.__name__`` on the first member of + each tuple and use that as the "id" to pytest.param(). + + + """ + return _fixture_functions.combinations(*comb, **kw) + + +def combinations_list(arg_iterable, **kw): + "As combination, but takes a single iterable" + return combinations(*arg_iterable, **kw) + + +def fixture(*arg, **kw): + return _fixture_functions.fixture(*arg, **kw) + + +def get_current_test_name(): + return _fixture_functions.get_current_test_name() + + +def mark_base_test_class(): + return _fixture_functions.mark_base_test_class() + + +class Config(object): + def __init__(self, db, db_opts, options, file_config): + self._set_name(db) + self.db = db + self.db_opts = db_opts + self.options = options + self.file_config = file_config + self.test_schema = "test_schema" + self.test_schema_2 = "test_schema_2" + + self.is_async = db.dialect.is_async and not util.asbool( + db.url.query.get("async_fallback", False) + ) + + _stack = collections.deque() + _configs = set() + + def _set_name(self, db): + if db.dialect.server_version_info: + svi = ".".join(str(tok) for tok in db.dialect.server_version_info) + self.name = "%s+%s_[%s]" % (db.name, db.driver, svi) + else: + self.name = "%s+%s" % (db.name, db.driver) + + @classmethod + def register(cls, db, db_opts, options, file_config): + """add a config as one of the global configs. + + If there are no configs set up yet, this config also + gets set as the "_current". + """ + global any_async + + cfg = Config(db, db_opts, options, file_config) + + # if any backends include an async driver, then ensure + # all setup/teardown and tests are wrapped in the maybe_async() + # decorator that will set up a greenlet context for async drivers. + any_async = any_async or cfg.is_async + + cls._configs.add(cfg) + return cfg + + @classmethod + def set_as_current(cls, config, namespace): + global db, _current, db_url, test_schema, test_schema_2, db_opts + _current = config + db_url = config.db.url + db_opts = config.db_opts + test_schema = config.test_schema + test_schema_2 = config.test_schema_2 + namespace.db = db = config.db + + @classmethod + def push_engine(cls, db, namespace): + assert _current, "Can't push without a default Config set up" + cls.push( + Config( + db, _current.db_opts, _current.options, _current.file_config + ), + namespace, + ) + + @classmethod + def push(cls, config, namespace): + cls._stack.append(_current) + cls.set_as_current(config, namespace) + + @classmethod + def pop(cls, namespace): + if cls._stack: + # a failed test w/ -x option can call reset() ahead of time + _current = cls._stack[-1] + del cls._stack[-1] + cls.set_as_current(_current, namespace) + + @classmethod + def reset(cls, namespace): + if cls._stack: + cls.set_as_current(cls._stack[0], namespace) + cls._stack.clear() + + @classmethod + def all_configs(cls): + return cls._configs + + @classmethod + def all_dbs(cls): + for cfg in cls.all_configs(): + yield cfg.db + + def skip_test(self, msg): + skip_test(msg) + + +def skip_test(msg): + raise _fixture_functions.skip_test_exception(msg) + + +def async_test(fn): + return _fixture_functions.async_test(fn) diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py new file mode 100644 index 0000000..b8be6b9 --- /dev/null +++ b/lib/sqlalchemy/testing/engines.py @@ -0,0 +1,465 @@ +# testing/engines.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import absolute_import + +import collections +import re +import warnings +import weakref + +from . import config +from .util import decorator +from .util import gc_collect +from .. import event +from .. import pool +from ..util import await_only + + +class ConnectionKiller(object): + def __init__(self): + self.proxy_refs = weakref.WeakKeyDictionary() + self.testing_engines = collections.defaultdict(set) + self.dbapi_connections = set() + + def add_pool(self, pool): + event.listen(pool, "checkout", self._add_conn) + event.listen(pool, "checkin", self._remove_conn) + event.listen(pool, "close", self._remove_conn) + event.listen(pool, "close_detached", self._remove_conn) + # note we are keeping "invalidated" here, as those are still + # opened connections we would like to roll back + + def _add_conn(self, dbapi_con, con_record, con_proxy): + self.dbapi_connections.add(dbapi_con) + self.proxy_refs[con_proxy] = True + + def _remove_conn(self, dbapi_conn, *arg): + self.dbapi_connections.discard(dbapi_conn) + + def add_engine(self, engine, scope): + self.add_pool(engine.pool) + + assert scope in ("class", "global", "function", "fixture") + self.testing_engines[scope].add(engine) + + def _safe(self, fn): + try: + fn() + except Exception as e: + warnings.warn( + "testing_reaper couldn't rollback/close connection: %s" % e + ) + + def rollback_all(self): + for rec in list(self.proxy_refs): + if rec is not None and rec.is_valid: + self._safe(rec.rollback) + + def checkin_all(self): + # run pool.checkin() for all ConnectionFairy instances we have + # tracked. + + for rec in list(self.proxy_refs): + if rec is not None and rec.is_valid: + self.dbapi_connections.discard(rec.dbapi_connection) + self._safe(rec._checkin) + + # for fairy refs that were GCed and could not close the connection, + # such as asyncio, roll back those remaining connections + for con in self.dbapi_connections: + self._safe(con.rollback) + self.dbapi_connections.clear() + + def close_all(self): + self.checkin_all() + + def prepare_for_drop_tables(self, connection): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return + + from . import provision + + provision.prepare_for_drop_tables(connection.engine.url, connection) + + def _drop_testing_engines(self, scope): + eng = self.testing_engines[scope] + for rec in list(eng): + for proxy_ref in list(self.proxy_refs): + if proxy_ref is not None and proxy_ref.is_valid: + if ( + proxy_ref._pool is not None + and proxy_ref._pool is rec.pool + ): + self._safe(proxy_ref._checkin) + if hasattr(rec, "sync_engine"): + await_only(rec.dispose()) + else: + rec.dispose() + eng.clear() + + def after_test(self): + self._drop_testing_engines("function") + + def after_test_outside_fixtures(self, test): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return + + if test.__class__.__leave_connections_for_teardown__: + return + + self.checkin_all() + + # on PostgreSQL, this will test for any "idle in transaction" + # connections. useful to identify tests with unusual patterns + # that can't be cleaned up correctly. + from . import provision + + with config.db.connect() as conn: + provision.prepare_for_drop_tables(conn.engine.url, conn) + + def stop_test_class_inside_fixtures(self): + self.checkin_all() + self._drop_testing_engines("function") + self._drop_testing_engines("class") + + def stop_test_class_outside_fixtures(self): + # ensure no refs to checked out connections at all. + + if pool.base._strong_ref_connection_records: + gc_collect() + + if pool.base._strong_ref_connection_records: + ln = len(pool.base._strong_ref_connection_records) + pool.base._strong_ref_connection_records.clear() + assert ( + False + ), "%d connection recs not cleared after test suite" % (ln) + + def final_cleanup(self): + self.checkin_all() + for scope in self.testing_engines: + self._drop_testing_engines(scope) + + def assert_all_closed(self): + for rec in self.proxy_refs: + if rec.is_valid: + assert False + + +testing_reaper = ConnectionKiller() + + +@decorator +def assert_conns_closed(fn, *args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.assert_all_closed() + + +@decorator +def rollback_open_connections(fn, *args, **kw): + """Decorator that rolls back all open connections after fn execution.""" + + try: + fn(*args, **kw) + finally: + testing_reaper.rollback_all() + + +@decorator +def close_first(fn, *args, **kw): + """Decorator that closes all connections before fn execution.""" + + testing_reaper.checkin_all() + fn(*args, **kw) + + +@decorator +def close_open_connections(fn, *args, **kw): + """Decorator that closes all connections after fn execution.""" + try: + fn(*args, **kw) + finally: + testing_reaper.checkin_all() + + +def all_dialects(exclude=None): + import sqlalchemy.dialects as d + + for name in d.__all__: + # TEMPORARY + if exclude and name in exclude: + continue + mod = getattr(d, name, None) + if not mod: + mod = getattr( + __import__("sqlalchemy.dialects.%s" % name).dialects, name + ) + yield mod.dialect() + + +class ReconnectFixture(object): + def __init__(self, dbapi): + self.dbapi = dbapi + self.connections = [] + self.is_stopped = False + + def __getattr__(self, key): + return getattr(self.dbapi, key) + + def connect(self, *args, **kwargs): + + conn = self.dbapi.connect(*args, **kwargs) + if self.is_stopped: + self._safe(conn.close) + curs = conn.cursor() # should fail on Oracle etc. + # should fail for everything that didn't fail + # above, connection is closed + curs.execute("select 1") + assert False, "simulated connect failure didn't work" + else: + self.connections.append(conn) + return conn + + def _safe(self, fn): + try: + fn() + except Exception as e: + warnings.warn("ReconnectFixture couldn't close connection: %s" % e) + + def shutdown(self, stop=False): + # TODO: this doesn't cover all cases + # as nicely as we'd like, namely MySQLdb. + # would need to implement R. Brewer's + # proxy server idea to get better + # coverage. + self.is_stopped = stop + for c in list(self.connections): + self._safe(c.close) + self.connections = [] + + def restart(self): + self.is_stopped = False + + +def reconnecting_engine(url=None, options=None): + url = url or config.db.url + dbapi = config.db.dialect.dbapi + if not options: + options = {} + options["module"] = ReconnectFixture(dbapi) + engine = testing_engine(url, options) + _dispose = engine.dispose + + def dispose(): + engine.dialect.dbapi.shutdown() + engine.dialect.dbapi.is_stopped = False + _dispose() + + engine.test_shutdown = engine.dialect.dbapi.shutdown + engine.test_restart = engine.dialect.dbapi.restart + engine.dispose = dispose + return engine + + +def testing_engine( + url=None, + options=None, + future=None, + asyncio=False, + transfer_staticpool=False, + _sqlite_savepoint=False, +): + """Produce an engine configured by --options with optional overrides.""" + + if asyncio: + assert not _sqlite_savepoint + from sqlalchemy.ext.asyncio import ( + create_async_engine as create_engine, + ) + elif future or ( + config.db and config.db._is_future and future is not False + ): + from sqlalchemy.future import create_engine + else: + from sqlalchemy import create_engine + from sqlalchemy.engine.url import make_url + + if not options: + use_reaper = True + scope = "function" + sqlite_savepoint = False + else: + use_reaper = options.pop("use_reaper", True) + scope = options.pop("scope", "function") + sqlite_savepoint = options.pop("sqlite_savepoint", False) + + url = url or config.db.url + + url = make_url(url) + if options is None: + if config.db is None or url.drivername == config.db.url.drivername: + options = config.db_opts + else: + options = {} + elif config.db is not None and url.drivername == config.db.url.drivername: + default_opt = config.db_opts.copy() + default_opt.update(options) + + engine = create_engine(url, **options) + + if sqlite_savepoint and engine.name == "sqlite": + # apply SQLite savepoint workaround + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + conn.exec_driver_sql("BEGIN") + + if transfer_staticpool: + from sqlalchemy.pool import StaticPool + + if config.db is not None and isinstance(config.db.pool, StaticPool): + use_reaper = False + engine.pool._transfer_from(config.db.pool) + + if scope == "global": + if asyncio: + engine.sync_engine._has_events = True + else: + engine._has_events = ( + True # enable event blocks, helps with profiling + ) + + if isinstance(engine.pool, pool.QueuePool): + engine.pool._timeout = 0 + engine.pool._max_overflow = 0 + if use_reaper: + testing_reaper.add_engine(engine, scope) + + return engine + + +def mock_engine(dialect_name=None): + """Provides a mocking engine based on the current testing.db. + + This is normally used to test DDL generation flow as emitted + by an Engine. + + It should not be used in other cases, as assert_compile() and + assert_sql_execution() are much better choices with fewer + moving parts. + + """ + + from sqlalchemy import create_mock_engine + + if not dialect_name: + dialect_name = config.db.name + + buffer = [] + + def executor(sql, *a, **kw): + buffer.append(sql) + + def assert_sql(stmts): + recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer] + assert recv == stmts, recv + + def print_sql(): + d = engine.dialect + return "\n".join(str(s.compile(dialect=d)) for s in engine.mock) + + engine = create_mock_engine(dialect_name + "://", executor) + assert not hasattr(engine, "mock") + engine.mock = buffer + engine.assert_sql = assert_sql + engine.print_sql = print_sql + return engine + + +class DBAPIProxyCursor(object): + """Proxy a DBAPI cursor. + + Tests can provide subclasses of this to intercept + DBAPI-level cursor operations. + + """ + + def __init__(self, engine, conn, *args, **kwargs): + self.engine = engine + self.connection = conn + self.cursor = conn.cursor(*args, **kwargs) + + def execute(self, stmt, parameters=None, **kw): + if parameters: + return self.cursor.execute(stmt, parameters, **kw) + else: + return self.cursor.execute(stmt, **kw) + + def executemany(self, stmt, params, **kw): + return self.cursor.executemany(stmt, params, **kw) + + def __iter__(self): + return iter(self.cursor) + + def __getattr__(self, key): + return getattr(self.cursor, key) + + +class DBAPIProxyConnection(object): + """Proxy a DBAPI connection. + + Tests can provide subclasses of this to intercept + DBAPI-level connection operations. + + """ + + def __init__(self, engine, cursor_cls): + self.conn = engine.pool._creator() + self.engine = engine + self.cursor_cls = cursor_cls + + def cursor(self, *args, **kwargs): + return self.cursor_cls(self.engine, self.conn, *args, **kwargs) + + def close(self): + self.conn.close() + + def __getattr__(self, key): + return getattr(self.conn, key) + + +def proxying_engine( + conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor +): + """Produce an engine that provides proxy hooks for + common methods. + + """ + + def mock_conn(): + return conn_cls(config.db, cursor_cls) + + def _wrap_do_on_connect(do_on_connect): + def go(dbapi_conn): + return do_on_connect(dbapi_conn.conn) + + return go + + return testing_engine( + options={ + "creator": mock_conn, + "_wrap_do_on_connect": _wrap_do_on_connect, + } + ) diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py new file mode 100644 index 0000000..8ea65d6 --- /dev/null +++ b/lib/sqlalchemy/testing/entities.py @@ -0,0 +1,111 @@ +# testing/entities.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import sqlalchemy as sa +from .. import exc as sa_exc +from ..util import compat + +_repr_stack = set() + + +class BasicEntity(object): + def __init__(self, **kw): + for key, value in kw.items(): + setattr(self, key, value) + + def __repr__(self): + if id(self) in _repr_stack: + return object.__repr__(self) + _repr_stack.add(id(self)) + try: + return "%s(%s)" % ( + (self.__class__.__name__), + ", ".join( + [ + "%s=%r" % (key, getattr(self, key)) + for key in sorted(self.__dict__.keys()) + if not key.startswith("_") + ] + ), + ) + finally: + _repr_stack.remove(id(self)) + + +_recursion_stack = set() + + +class ComparableMixin(object): + def __ne__(self, other): + return not self.__eq__(other) + + def __eq__(self, other): + """'Deep, sparse compare. + + Deeply compare two entities, following the non-None attributes of the + non-persisted object, if possible. + + """ + if other is self: + return True + elif not self.__class__ == other.__class__: + return False + + if id(self) in _recursion_stack: + return True + _recursion_stack.add(id(self)) + + try: + # pick the entity that's not SA persisted as the source + try: + self_key = sa.orm.attributes.instance_state(self).key + except sa.orm.exc.NO_STATE: + self_key = None + + if other is None: + a = self + b = other + elif self_key is not None: + a = other + b = self + else: + a = self + b = other + + for attr in list(a.__dict__): + if attr.startswith("_"): + continue + value = getattr(a, attr) + + try: + # handle lazy loader errors + battr = getattr(b, attr) + except (AttributeError, sa_exc.UnboundExecutionError): + return False + + if hasattr(value, "__iter__") and not isinstance( + value, compat.string_types + ): + if hasattr(value, "__getitem__") and not hasattr( + value, "keys" + ): + if list(value) != list(battr): + return False + else: + if set(value) != set(battr): + return False + else: + if value is not None and value != battr: + return False + return True + finally: + _recursion_stack.remove(id(self)) + + +class ComparableEntity(ComparableMixin, BasicEntity): + def __hash__(self): + return hash(self.__class__) diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py new file mode 100644 index 0000000..521a4aa --- /dev/null +++ b/lib/sqlalchemy/testing/exclusions.py @@ -0,0 +1,465 @@ +# testing/exclusions.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +import contextlib +import operator +import re +import sys + +from . import config +from .. import util +from ..util import decorator +from ..util.compat import inspect_getfullargspec + + +def skip_if(predicate, reason=None): + rule = compound() + pred = _as_predicate(predicate, reason) + rule.skips.add(pred) + return rule + + +def fails_if(predicate, reason=None): + rule = compound() + pred = _as_predicate(predicate, reason) + rule.fails.add(pred) + return rule + + +class compound(object): + def __init__(self): + self.fails = set() + self.skips = set() + self.tags = set() + + def __add__(self, other): + return self.add(other) + + def as_skips(self): + rule = compound() + rule.skips.update(self.skips) + rule.skips.update(self.fails) + rule.tags.update(self.tags) + return rule + + def add(self, *others): + copy = compound() + copy.fails.update(self.fails) + copy.skips.update(self.skips) + copy.tags.update(self.tags) + for other in others: + copy.fails.update(other.fails) + copy.skips.update(other.skips) + copy.tags.update(other.tags) + return copy + + def not_(self): + copy = compound() + copy.fails.update(NotPredicate(fail) for fail in self.fails) + copy.skips.update(NotPredicate(skip) for skip in self.skips) + copy.tags.update(self.tags) + return copy + + @property + def enabled(self): + return self.enabled_for_config(config._current) + + def enabled_for_config(self, config): + for predicate in self.skips.union(self.fails): + if predicate(config): + return False + else: + return True + + def matching_config_reasons(self, config): + return [ + predicate._as_string(config) + for predicate in self.skips.union(self.fails) + if predicate(config) + ] + + def include_test(self, include_tags, exclude_tags): + return bool( + not self.tags.intersection(exclude_tags) + and (not include_tags or self.tags.intersection(include_tags)) + ) + + def _extend(self, other): + self.skips.update(other.skips) + self.fails.update(other.fails) + self.tags.update(other.tags) + + def __call__(self, fn): + if hasattr(fn, "_sa_exclusion_extend"): + fn._sa_exclusion_extend._extend(self) + return fn + + @decorator + def decorate(fn, *args, **kw): + return self._do(config._current, fn, *args, **kw) + + decorated = decorate(fn) + decorated._sa_exclusion_extend = self + return decorated + + @contextlib.contextmanager + def fail_if(self): + all_fails = compound() + all_fails.fails.update(self.skips.union(self.fails)) + + try: + yield + except Exception as ex: + all_fails._expect_failure(config._current, ex) + else: + all_fails._expect_success(config._current) + + def _do(self, cfg, fn, *args, **kw): + for skip in self.skips: + if skip(cfg): + msg = "'%s' : %s" % ( + config.get_current_test_name(), + skip._as_string(cfg), + ) + config.skip_test(msg) + + try: + return_value = fn(*args, **kw) + except Exception as ex: + self._expect_failure(cfg, ex, name=fn.__name__) + else: + self._expect_success(cfg, name=fn.__name__) + return return_value + + def _expect_failure(self, config, ex, name="block"): + for fail in self.fails: + if fail(config): + if util.py2k: + str_ex = unicode(ex).encode( # noqa: F821 + "utf-8", errors="ignore" + ) + else: + str_ex = str(ex) + print( + ( + "%s failed as expected (%s): %s " + % (name, fail._as_string(config), str_ex) + ) + ) + break + else: + util.raise_(ex, with_traceback=sys.exc_info()[2]) + + def _expect_success(self, config, name="block"): + if not self.fails: + return + + for fail in self.fails: + if fail(config): + raise AssertionError( + "Unexpected success for '%s' (%s)" + % ( + name, + " and ".join( + fail._as_string(config) for fail in self.fails + ), + ) + ) + + +def requires_tag(tagname): + return tags([tagname]) + + +def tags(tagnames): + comp = compound() + comp.tags.update(tagnames) + return comp + + +def only_if(predicate, reason=None): + predicate = _as_predicate(predicate) + return skip_if(NotPredicate(predicate), reason) + + +def succeeds_if(predicate, reason=None): + predicate = _as_predicate(predicate) + return fails_if(NotPredicate(predicate), reason) + + +class Predicate(object): + @classmethod + def as_predicate(cls, predicate, description=None): + if isinstance(predicate, compound): + return cls.as_predicate(predicate.enabled_for_config, description) + elif isinstance(predicate, Predicate): + if description and predicate.description is None: + predicate.description = description + return predicate + elif isinstance(predicate, (list, set)): + return OrPredicate( + [cls.as_predicate(pred) for pred in predicate], description + ) + elif isinstance(predicate, tuple): + return SpecPredicate(*predicate) + elif isinstance(predicate, util.string_types): + tokens = re.match( + r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate + ) + if not tokens: + raise ValueError( + "Couldn't locate DB name in predicate: %r" % predicate + ) + db = tokens.group(1) + op = tokens.group(2) + spec = ( + tuple(int(d) for d in tokens.group(3).split(".")) + if tokens.group(3) + else None + ) + + return SpecPredicate(db, op, spec, description=description) + elif callable(predicate): + return LambdaPredicate(predicate, description) + else: + assert False, "unknown predicate type: %s" % predicate + + def _format_description(self, config, negate=False): + bool_ = self(config) + if negate: + bool_ = not negate + return self.description % { + "driver": config.db.url.get_driver_name() + if config + else "<no driver>", + "database": config.db.url.get_backend_name() + if config + else "<no database>", + "doesnt_support": "doesn't support" if bool_ else "does support", + "does_support": "does support" if bool_ else "doesn't support", + } + + def _as_string(self, config=None, negate=False): + raise NotImplementedError() + + +class BooleanPredicate(Predicate): + def __init__(self, value, description=None): + self.value = value + self.description = description or "boolean %s" % value + + def __call__(self, config): + return self.value + + def _as_string(self, config, negate=False): + return self._format_description(config, negate=negate) + + +class SpecPredicate(Predicate): + def __init__(self, db, op=None, spec=None, description=None): + self.db = db + self.op = op + self.spec = spec + self.description = description + + _ops = { + "<": operator.lt, + ">": operator.gt, + "==": operator.eq, + "!=": operator.ne, + "<=": operator.le, + ">=": operator.ge, + "in": operator.contains, + "between": lambda val, pair: val >= pair[0] and val <= pair[1], + } + + def __call__(self, config): + if config is None: + return False + + engine = config.db + + if "+" in self.db: + dialect, driver = self.db.split("+") + else: + dialect, driver = self.db, None + + if dialect and engine.name != dialect: + return False + if driver is not None and engine.driver != driver: + return False + + if self.op is not None: + assert driver is None, "DBAPI version specs not supported yet" + + version = _server_version(engine) + oper = ( + hasattr(self.op, "__call__") and self.op or self._ops[self.op] + ) + return oper(version, self.spec) + else: + return True + + def _as_string(self, config, negate=False): + if self.description is not None: + return self._format_description(config) + elif self.op is None: + if negate: + return "not %s" % self.db + else: + return "%s" % self.db + else: + if negate: + return "not %s %s %s" % (self.db, self.op, self.spec) + else: + return "%s %s %s" % (self.db, self.op, self.spec) + + +class LambdaPredicate(Predicate): + def __init__(self, lambda_, description=None, args=None, kw=None): + spec = inspect_getfullargspec(lambda_) + if not spec[0]: + self.lambda_ = lambda db: lambda_() + else: + self.lambda_ = lambda_ + self.args = args or () + self.kw = kw or {} + if description: + self.description = description + elif lambda_.__doc__: + self.description = lambda_.__doc__ + else: + self.description = "custom function" + + def __call__(self, config): + return self.lambda_(config) + + def _as_string(self, config, negate=False): + return self._format_description(config) + + +class NotPredicate(Predicate): + def __init__(self, predicate, description=None): + self.predicate = predicate + self.description = description + + def __call__(self, config): + return not self.predicate(config) + + def _as_string(self, config, negate=False): + if self.description: + return self._format_description(config, not negate) + else: + return self.predicate._as_string(config, not negate) + + +class OrPredicate(Predicate): + def __init__(self, predicates, description=None): + self.predicates = predicates + self.description = description + + def __call__(self, config): + for pred in self.predicates: + if pred(config): + return True + return False + + def _eval_str(self, config, negate=False): + if negate: + conjunction = " and " + else: + conjunction = " or " + return conjunction.join( + p._as_string(config, negate=negate) for p in self.predicates + ) + + def _negation_str(self, config): + if self.description is not None: + return "Not " + self._format_description(config) + else: + return self._eval_str(config, negate=True) + + def _as_string(self, config, negate=False): + if negate: + return self._negation_str(config) + else: + if self.description is not None: + return self._format_description(config) + else: + return self._eval_str(config) + + +_as_predicate = Predicate.as_predicate + + +def _is_excluded(db, op, spec): + return SpecPredicate(db, op, spec)(config._current) + + +def _server_version(engine): + """Return a server_version_info tuple.""" + + # force metadata to be retrieved + conn = engine.connect() + version = getattr(engine.dialect, "server_version_info", None) + if version is None: + version = () + conn.close() + return version + + +def db_spec(*dbs): + return OrPredicate([Predicate.as_predicate(db) for db in dbs]) + + +def open(): # noqa + return skip_if(BooleanPredicate(False, "mark as execute")) + + +def closed(): + return skip_if(BooleanPredicate(True, "marked as skip")) + + +def fails(reason=None): + return fails_if(BooleanPredicate(True, reason or "expected to fail")) + + +@decorator +def future(fn, *arg): + return fails_if(LambdaPredicate(fn), "Future feature") + + +def fails_on(db, reason=None): + return fails_if(db, reason) + + +def fails_on_everything_except(*dbs): + return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs])) + + +def skip(db, reason=None): + return skip_if(db, reason) + + +def only_on(dbs, reason=None): + return only_if( + OrPredicate( + [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)] + ) + ) + + +def exclude(db, op, spec, reason=None): + return skip_if(SpecPredicate(db, op, spec), reason) + + +def against(config, *queries): + assert queries, "no queries sent!" + return OrPredicate([Predicate.as_predicate(query) for query in queries])( + config + ) diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py new file mode 100644 index 0000000..0a2d63b --- /dev/null +++ b/lib/sqlalchemy/testing/fixtures.py @@ -0,0 +1,870 @@ +# testing/fixtures.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import contextlib +import re +import sys + +import sqlalchemy as sa +from . import assertions +from . import config +from . import schema +from .entities import BasicEntity +from .entities import ComparableEntity +from .entities import ComparableMixin # noqa +from .util import adict +from .util import drop_all_tables_from_metadata +from .. import event +from .. import util +from ..orm import declarative_base +from ..orm import registry +from ..orm.decl_api import DeclarativeMeta +from ..schema import sort_tables_and_constraints + + +@config.mark_base_test_class() +class TestBase(object): + # A sequence of requirement names matching testing.requires decorators + __requires__ = () + + # A sequence of dialect names to exclude from the test class. + __unsupported_on__ = () + + # If present, test class is only runnable for the *single* specified + # dialect. If you need multiple, use __unsupported_on__ and invert. + __only_on__ = None + + # A sequence of no-arg callables. If any are True, the entire testcase is + # skipped. + __skip_if__ = None + + # if True, the testing reaper will not attempt to touch connection + # state after a test is completed and before the outer teardown + # starts + __leave_connections_for_teardown__ = False + + def assert_(self, val, msg=None): + assert val, msg + + @config.fixture() + def nocache(self): + _cache = config.db._compiled_cache + config.db._compiled_cache = None + yield + config.db._compiled_cache = _cache + + @config.fixture() + def connection_no_trans(self): + eng = getattr(self, "bind", None) or config.db + + with eng.connect() as conn: + yield conn + + @config.fixture() + def connection(self): + global _connection_fixture_connection + + eng = getattr(self, "bind", None) or config.db + + conn = eng.connect() + trans = conn.begin() + + _connection_fixture_connection = conn + yield conn + + _connection_fixture_connection = None + + if trans.is_active: + trans.rollback() + # trans would not be active here if the test is using + # the legacy @provide_metadata decorator still, as it will + # run a close all connections. + conn.close() + + @config.fixture() + def close_result_when_finished(self): + to_close = [] + to_consume = [] + + def go(result, consume=False): + to_close.append(result) + if consume: + to_consume.append(result) + + yield go + for r in to_consume: + try: + r.all() + except: + pass + for r in to_close: + try: + r.close() + except: + pass + + @config.fixture() + def registry(self, metadata): + reg = registry(metadata=metadata) + yield reg + reg.dispose() + + @config.fixture + def decl_base(self, registry): + return registry.generate_base() + + @config.fixture() + def future_connection(self, future_engine, connection): + # integrate the future_engine and connection fixtures so + # that users of the "connection" fixture will get at the + # "future" connection + yield connection + + @config.fixture() + def future_engine(self): + eng = getattr(self, "bind", None) or config.db + with _push_future_engine(eng): + yield + + @config.fixture() + def testing_engine(self): + from . import engines + + def gen_testing_engine( + url=None, + options=None, + future=None, + asyncio=False, + transfer_staticpool=False, + ): + if options is None: + options = {} + options["scope"] = "fixture" + return engines.testing_engine( + url=url, + options=options, + future=future, + asyncio=asyncio, + transfer_staticpool=transfer_staticpool, + ) + + yield gen_testing_engine + + engines.testing_reaper._drop_testing_engines("fixture") + + @config.fixture() + def async_testing_engine(self, testing_engine): + def go(**kw): + kw["asyncio"] = True + return testing_engine(**kw) + + return go + + @config.fixture + def fixture_session(self): + return fixture_session() + + @config.fixture() + def metadata(self, request): + """Provide bound MetaData for a single test, dropping afterwards.""" + + from ..sql import schema + + metadata = schema.MetaData() + request.instance.metadata = metadata + yield metadata + del request.instance.metadata + + if ( + _connection_fixture_connection + and _connection_fixture_connection.in_transaction() + ): + trans = _connection_fixture_connection.get_transaction() + trans.rollback() + with _connection_fixture_connection.begin(): + drop_all_tables_from_metadata( + metadata, _connection_fixture_connection + ) + else: + drop_all_tables_from_metadata(metadata, config.db) + + @config.fixture( + params=[ + (rollback, second_operation, begin_nested) + for rollback in (True, False) + for second_operation in ("none", "execute", "begin") + for begin_nested in ( + True, + False, + ) + ] + ) + def trans_ctx_manager_fixture(self, request, metadata): + rollback, second_operation, begin_nested = request.param + + from sqlalchemy import Table, Column, Integer, func, select + from . import eq_ + + t = Table("test", metadata, Column("data", Integer)) + eng = getattr(self, "bind", None) or config.db + + t.create(eng) + + def run_test(subject, trans_on_subject, execute_on_subject): + with subject.begin() as trans: + + if begin_nested: + if not config.requirements.savepoints.enabled: + config.skip_test("savepoints not enabled") + if execute_on_subject: + nested_trans = subject.begin_nested() + else: + nested_trans = trans.begin_nested() + + with nested_trans: + if execute_on_subject: + subject.execute(t.insert(), {"data": 10}) + else: + trans.execute(t.insert(), {"data": 10}) + + # for nested trans, we always commit/rollback on the + # "nested trans" object itself. + # only Session(future=False) will affect savepoint + # transaction for session.commit/rollback + + if rollback: + nested_trans.rollback() + else: + nested_trans.commit() + + if second_operation != "none": + with assertions.expect_raises_message( + sa.exc.InvalidRequestError, + "Can't operate on closed transaction " + "inside context " + "manager. Please complete the context " + "manager " + "before emitting further commands.", + ): + if second_operation == "execute": + if execute_on_subject: + subject.execute( + t.insert(), {"data": 12} + ) + else: + trans.execute(t.insert(), {"data": 12}) + elif second_operation == "begin": + if execute_on_subject: + subject.begin_nested() + else: + trans.begin_nested() + + # outside the nested trans block, but still inside the + # transaction block, we can run SQL, and it will be + # committed + if execute_on_subject: + subject.execute(t.insert(), {"data": 14}) + else: + trans.execute(t.insert(), {"data": 14}) + + else: + if execute_on_subject: + subject.execute(t.insert(), {"data": 10}) + else: + trans.execute(t.insert(), {"data": 10}) + + if trans_on_subject: + if rollback: + subject.rollback() + else: + subject.commit() + else: + if rollback: + trans.rollback() + else: + trans.commit() + + if second_operation != "none": + with assertions.expect_raises_message( + sa.exc.InvalidRequestError, + "Can't operate on closed transaction inside " + "context " + "manager. Please complete the context manager " + "before emitting further commands.", + ): + if second_operation == "execute": + if execute_on_subject: + subject.execute(t.insert(), {"data": 12}) + else: + trans.execute(t.insert(), {"data": 12}) + elif second_operation == "begin": + if hasattr(trans, "begin"): + trans.begin() + else: + subject.begin() + elif second_operation == "begin_nested": + if execute_on_subject: + subject.begin_nested() + else: + trans.begin_nested() + + expected_committed = 0 + if begin_nested: + # begin_nested variant, we inserted a row after the nested + # block + expected_committed += 1 + if not rollback: + # not rollback variant, our row inserted in the target + # block itself would be committed + expected_committed += 1 + + if execute_on_subject: + eq_( + subject.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + else: + with subject.connect() as conn: + eq_( + conn.scalar(select(func.count()).select_from(t)), + expected_committed, + ) + + return run_test + + +_connection_fixture_connection = None + + +@contextlib.contextmanager +def _push_future_engine(engine): + + from ..future.engine import Engine + from sqlalchemy import testing + + facade = Engine._future_facade(engine) + config._current.push_engine(facade, testing) + + yield facade + + config._current.pop(testing) + + +class FutureEngineMixin(object): + @config.fixture(autouse=True, scope="class") + def _push_future_engine(self): + eng = getattr(self, "bind", None) or config.db + with _push_future_engine(eng): + yield + + +class TablesTest(TestBase): + + # 'once', None + run_setup_bind = "once" + + # 'once', 'each', None + run_define_tables = "once" + + # 'once', 'each', None + run_create_tables = "once" + + # 'once', 'each', None + run_inserts = "each" + + # 'each', None + run_deletes = "each" + + # 'once', None + run_dispose_bind = None + + bind = None + _tables_metadata = None + tables = None + other = None + sequences = None + + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ + cls._init_class() + + cls._setup_once_tables() + + cls._setup_once_inserts() + + yield + + cls._teardown_once_metadata_bind() + + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): + self._setup_each_tables() + self._setup_each_inserts() + + yield + + self._teardown_each_tables() + + @property + def tables_test_metadata(self): + return self._tables_metadata + + @classmethod + def _init_class(cls): + if cls.run_define_tables == "each": + if cls.run_create_tables == "once": + cls.run_create_tables = "each" + assert cls.run_inserts in ("each", None) + + cls.other = adict() + cls.tables = adict() + cls.sequences = adict() + + cls.bind = cls.setup_bind() + cls._tables_metadata = sa.MetaData() + + @classmethod + def _setup_once_inserts(cls): + if cls.run_inserts == "once": + cls._load_fixtures() + with cls.bind.begin() as conn: + cls.insert_data(conn) + + @classmethod + def _setup_once_tables(cls): + if cls.run_define_tables == "once": + cls.define_tables(cls._tables_metadata) + if cls.run_create_tables == "once": + cls._tables_metadata.create_all(cls.bind) + cls.tables.update(cls._tables_metadata.tables) + cls.sequences.update(cls._tables_metadata._sequences) + + def _setup_each_tables(self): + if self.run_define_tables == "each": + self.define_tables(self._tables_metadata) + if self.run_create_tables == "each": + self._tables_metadata.create_all(self.bind) + self.tables.update(self._tables_metadata.tables) + self.sequences.update(self._tables_metadata._sequences) + elif self.run_create_tables == "each": + self._tables_metadata.create_all(self.bind) + + def _setup_each_inserts(self): + if self.run_inserts == "each": + self._load_fixtures() + with self.bind.begin() as conn: + self.insert_data(conn) + + def _teardown_each_tables(self): + if self.run_define_tables == "each": + self.tables.clear() + if self.run_create_tables == "each": + drop_all_tables_from_metadata(self._tables_metadata, self.bind) + self._tables_metadata.clear() + elif self.run_create_tables == "each": + drop_all_tables_from_metadata(self._tables_metadata, self.bind) + + savepoints = getattr(config.requirements, "savepoints", False) + if savepoints: + savepoints = savepoints.enabled + + # no need to run deletes if tables are recreated on setup + if ( + self.run_define_tables != "each" + and self.run_create_tables != "each" + and self.run_deletes == "each" + ): + with self.bind.begin() as conn: + for table in reversed( + [ + t + for (t, fks) in sort_tables_and_constraints( + self._tables_metadata.tables.values() + ) + if t is not None + ] + ): + try: + if savepoints: + with conn.begin_nested(): + conn.execute(table.delete()) + else: + conn.execute(table.delete()) + except sa.exc.DBAPIError as ex: + util.print_( + ("Error emptying table %s: %r" % (table, ex)), + file=sys.stderr, + ) + + @classmethod + def _teardown_once_metadata_bind(cls): + if cls.run_create_tables: + drop_all_tables_from_metadata(cls._tables_metadata, cls.bind) + + if cls.run_dispose_bind == "once": + cls.dispose_bind(cls.bind) + + cls._tables_metadata.bind = None + + if cls.run_setup_bind is not None: + cls.bind = None + + @classmethod + def setup_bind(cls): + return config.db + + @classmethod + def dispose_bind(cls, bind): + if hasattr(bind, "dispose"): + bind.dispose() + elif hasattr(bind, "close"): + bind.close() + + @classmethod + def define_tables(cls, metadata): + pass + + @classmethod + def fixtures(cls): + return {} + + @classmethod + def insert_data(cls, connection): + pass + + def sql_count_(self, count, fn): + self.assert_sql_count(self.bind, fn, count) + + def sql_eq_(self, callable_, statements): + self.assert_sql(self.bind, callable_, statements) + + @classmethod + def _load_fixtures(cls): + """Insert rows as represented by the fixtures() method.""" + headers, rows = {}, {} + for table, data in cls.fixtures().items(): + if len(data) < 2: + continue + if isinstance(table, util.string_types): + table = cls.tables[table] + headers[table] = data[0] + rows[table] = data[1:] + for table, fks in sort_tables_and_constraints( + cls._tables_metadata.tables.values() + ): + if table is None: + continue + if table not in headers: + continue + with cls.bind.begin() as conn: + conn.execute( + table.insert(), + [ + dict(zip(headers[table], column_values)) + for column_values in rows[table] + ], + ) + + +class NoCache(object): + @config.fixture(autouse=True, scope="function") + def _disable_cache(self): + _cache = config.db._compiled_cache + config.db._compiled_cache = None + yield + config.db._compiled_cache = _cache + + +class RemovesEvents(object): + @util.memoized_property + def _event_fns(self): + return set() + + def event_listen(self, target, name, fn, **kw): + self._event_fns.add((target, name, fn)) + event.listen(target, name, fn, **kw) + + @config.fixture(autouse=True, scope="function") + def _remove_events(self): + yield + for key in self._event_fns: + event.remove(*key) + + +_fixture_sessions = set() + + +def fixture_session(**kw): + kw.setdefault("autoflush", True) + kw.setdefault("expire_on_commit", True) + + bind = kw.pop("bind", config.db) + + sess = sa.orm.Session(bind, **kw) + _fixture_sessions.add(sess) + return sess + + +def _close_all_sessions(): + # will close all still-referenced sessions + sa.orm.session.close_all_sessions() + _fixture_sessions.clear() + + +def stop_test_class_inside_fixtures(cls): + _close_all_sessions() + sa.orm.clear_mappers() + + +def after_test(): + if _fixture_sessions: + _close_all_sessions() + + +class ORMTest(TestBase): + pass + + +class MappedTest(TablesTest, assertions.AssertsExecutionResults): + # 'once', 'each', None + run_setup_classes = "once" + + # 'once', 'each', None + run_setup_mappers = "each" + + classes = None + + @config.fixture(autouse=True, scope="class") + def _setup_tables_test_class(self): + cls = self.__class__ + cls._init_class() + + if cls.classes is None: + cls.classes = adict() + + cls._setup_once_tables() + cls._setup_once_classes() + cls._setup_once_mappers() + cls._setup_once_inserts() + + yield + + cls._teardown_once_class() + cls._teardown_once_metadata_bind() + + @config.fixture(autouse=True, scope="function") + def _setup_tables_test_instance(self): + self._setup_each_tables() + self._setup_each_classes() + self._setup_each_mappers() + self._setup_each_inserts() + + yield + + sa.orm.session.close_all_sessions() + self._teardown_each_mappers() + self._teardown_each_classes() + self._teardown_each_tables() + + @classmethod + def _teardown_once_class(cls): + cls.classes.clear() + + @classmethod + def _setup_once_classes(cls): + if cls.run_setup_classes == "once": + cls._with_register_classes(cls.setup_classes) + + @classmethod + def _setup_once_mappers(cls): + if cls.run_setup_mappers == "once": + cls.mapper_registry, cls.mapper = cls._generate_registry() + cls._with_register_classes(cls.setup_mappers) + + def _setup_each_mappers(self): + if self.run_setup_mappers != "once": + ( + self.__class__.mapper_registry, + self.__class__.mapper, + ) = self._generate_registry() + + if self.run_setup_mappers == "each": + self._with_register_classes(self.setup_mappers) + + def _setup_each_classes(self): + if self.run_setup_classes == "each": + self._with_register_classes(self.setup_classes) + + @classmethod + def _generate_registry(cls): + decl = registry(metadata=cls._tables_metadata) + return decl, decl.map_imperatively + + @classmethod + def _with_register_classes(cls, fn): + """Run a setup method, framing the operation with a Base class + that will catch new subclasses to be established within + the "classes" registry. + + """ + cls_registry = cls.classes + + assert cls_registry is not None + + class FindFixture(type): + def __init__(cls, classname, bases, dict_): + cls_registry[classname] = cls + type.__init__(cls, classname, bases, dict_) + + class _Base(util.with_metaclass(FindFixture, object)): + pass + + class Basic(BasicEntity, _Base): + pass + + class Comparable(ComparableEntity, _Base): + pass + + cls.Basic = Basic + cls.Comparable = Comparable + fn() + + def _teardown_each_mappers(self): + # some tests create mappers in the test bodies + # and will define setup_mappers as None - + # clear mappers in any case + if self.run_setup_mappers != "once": + sa.orm.clear_mappers() + + def _teardown_each_classes(self): + if self.run_setup_classes != "once": + self.classes.clear() + + @classmethod + def setup_classes(cls): + pass + + @classmethod + def setup_mappers(cls): + pass + + +class DeclarativeMappedTest(MappedTest): + run_setup_classes = "once" + run_setup_mappers = "once" + + @classmethod + def _setup_once_tables(cls): + pass + + @classmethod + def _with_register_classes(cls, fn): + cls_registry = cls.classes + + class FindFixtureDeclarative(DeclarativeMeta): + def __init__(cls, classname, bases, dict_): + cls_registry[classname] = cls + DeclarativeMeta.__init__(cls, classname, bases, dict_) + + class DeclarativeBasic(object): + __table_cls__ = schema.Table + + _DeclBase = declarative_base( + metadata=cls._tables_metadata, + metaclass=FindFixtureDeclarative, + cls=DeclarativeBasic, + ) + + cls.DeclarativeBasic = _DeclBase + + # sets up cls.Basic which is helpful for things like composite + # classes + super(DeclarativeMappedTest, cls)._with_register_classes(fn) + + if cls._tables_metadata.tables and cls.run_create_tables: + cls._tables_metadata.create_all(config.db) + + +class ComputedReflectionFixtureTest(TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + __requires__ = ("computed_columns", "table_reflection") + + regexp = re.compile(r"[\[\]\(\)\s`'\"]*") + + def normalize(self, text): + return self.regexp.sub("", text).lower() + + @classmethod + def define_tables(cls, metadata): + from .. import Integer + from .. import testing + from ..schema import Column + from ..schema import Computed + from ..schema import Table + + Table( + "computed_default_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_col", Integer, Computed("normal + 42")), + Column("with_default", Integer, server_default="42"), + ) + + t = Table( + "computed_column_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_no_flag", Integer, Computed("normal + 42")), + ) + + if testing.requires.schemas.enabled: + t2 = Table( + "computed_column_table", + metadata, + Column("id", Integer, primary_key=True), + Column("normal", Integer), + Column("computed_no_flag", Integer, Computed("normal / 42")), + schema=config.test_schema, + ) + + if testing.requires.computed_columns_virtual.enabled: + t.append_column( + Column( + "computed_virtual", + Integer, + Computed("normal + 2", persisted=False), + ) + ) + if testing.requires.schemas.enabled: + t2.append_column( + Column( + "computed_virtual", + Integer, + Computed("normal / 2", persisted=False), + ) + ) + if testing.requires.computed_columns_stored.enabled: + t.append_column( + Column( + "computed_stored", + Integer, + Computed("normal - 42", persisted=True), + ) + ) + if testing.requires.schemas.enabled: + t2.append_column( + Column( + "computed_stored", + Integer, + Computed("normal * 42", persisted=True), + ) + ) diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py new file mode 100644 index 0000000..e333c70 --- /dev/null +++ b/lib/sqlalchemy/testing/mock.py @@ -0,0 +1,32 @@ +# testing/mock.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Import stub for mock library. +""" +from __future__ import absolute_import + +from ..util import py3k + + +if py3k: + from unittest.mock import MagicMock + from unittest.mock import Mock + from unittest.mock import call + from unittest.mock import patch + from unittest.mock import ANY +else: + try: + from mock import MagicMock # noqa + from mock import Mock # noqa + from mock import call # noqa + from mock import patch # noqa + from mock import ANY # noqa + except ImportError: + raise ImportError( + "SQLAlchemy's test suite requires the " + "'mock' library as of 0.8.2." + ) diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py new file mode 100644 index 0000000..f05960c --- /dev/null +++ b/lib/sqlalchemy/testing/pickleable.py @@ -0,0 +1,151 @@ +# testing/pickleable.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Classes used in pickling tests, need to be at the module level for +unpickling. +""" + +from . import fixtures +from ..schema import Column +from ..types import String + + +class User(fixtures.ComparableEntity): + pass + + +class Order(fixtures.ComparableEntity): + pass + + +class Dingaling(fixtures.ComparableEntity): + pass + + +class EmailUser(User): + pass + + +class Address(fixtures.ComparableEntity): + pass + + +# TODO: these are kind of arbitrary.... +class Child1(fixtures.ComparableEntity): + pass + + +class Child2(fixtures.ComparableEntity): + pass + + +class Parent(fixtures.ComparableEntity): + pass + + +class Screen(object): + def __init__(self, obj, parent=None): + self.obj = obj + self.parent = parent + + +class Mixin(object): + email_address = Column(String) + + +class AddressWMixin(Mixin, fixtures.ComparableEntity): + pass + + +class Foo(object): + def __init__(self, moredata, stuff="im stuff"): + self.data = "im data" + self.stuff = stuff + self.moredata = moredata + + __hash__ = object.__hash__ + + def __eq__(self, other): + return ( + other.data == self.data + and other.stuff == self.stuff + and other.moredata == self.moredata + ) + + +class Bar(object): + def __init__(self, x, y): + self.x = x + self.y = y + + __hash__ = object.__hash__ + + def __eq__(self, other): + return ( + other.__class__ is self.__class__ + and other.x == self.x + and other.y == self.y + ) + + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + + +class OldSchool: + def __init__(self, x, y): + self.x = x + self.y = y + + def __eq__(self, other): + return ( + other.__class__ is self.__class__ + and other.x == self.x + and other.y == self.y + ) + + +class OldSchoolWithoutCompare: + def __init__(self, x, y): + self.x = x + self.y = y + + +class BarWithoutCompare(object): + def __init__(self, x, y): + self.x = x + self.y = y + + def __str__(self): + return "Bar(%d, %d)" % (self.x, self.y) + + +class NotComparable(object): + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return NotImplemented + + def __ne__(self, other): + return NotImplemented + + +class BrokenComparable(object): + def __init__(self, data): + self.data = data + + def __hash__(self): + return id(self) + + def __eq__(self, other): + raise NotImplementedError + + def __ne__(self, other): + raise NotImplementedError diff --git a/lib/sqlalchemy/testing/plugin/__init__.py b/lib/sqlalchemy/testing/plugin/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/__init__.py diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py new file mode 100644 index 0000000..6721f48 --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/bootstrap.py @@ -0,0 +1,54 @@ +""" +Bootstrapper for test framework plugins. + +The entire rationale for this system is to get the modules in plugin/ +imported without importing all of the supporting library, so that we can +set up things for testing before coverage starts. + +The rationale for all of plugin/ being *in* the supporting library in the +first place is so that the testing and plugin suite is available to other +libraries, mainly external SQLAlchemy and Alembic dialects, to make use +of the same test environment and standard suites available to +SQLAlchemy/Alembic themselves without the need to ship/install a separate +package outside of SQLAlchemy. + + +""" + +import os +import sys + + +bootstrap_file = locals()["bootstrap_file"] +to_bootstrap = locals()["to_bootstrap"] + + +def load_file_as_module(name): + path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name) + + if sys.version_info >= (3, 5): + import importlib.util + + spec = importlib.util.spec_from_file_location(name, path) + assert spec is not None + assert spec.loader is not None + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + else: + import imp + + mod = imp.load_source(name, path) + + return mod + + +if to_bootstrap == "pytest": + sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base") + sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True + if sys.version_info < (3, 0): + sys.modules["sqla_reinvent_fixtures"] = load_file_as_module( + "reinvent_fixtures_py2k" + ) + sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin") +else: + raise Exception("unknown bootstrap: %s" % to_bootstrap) # noqa diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py new file mode 100644 index 0000000..d59564e --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/plugin_base.py @@ -0,0 +1,789 @@ +# plugin/plugin_base.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Testing extensions. + +this module is designed to work as a testing-framework-agnostic library, +created so that multiple test frameworks can be supported at once +(mostly so that we can migrate to new ones). The current target +is pytest. + +""" + +from __future__ import absolute_import + +import abc +import logging +import re +import sys + +# flag which indicates we are in the SQLAlchemy testing suite, +# and not that of Alembic or a third party dialect. +bootstrapped_as_sqlalchemy = False + +log = logging.getLogger("sqlalchemy.testing.plugin_base") + + +py3k = sys.version_info >= (3, 0) + +if py3k: + import configparser + + ABC = abc.ABC +else: + import ConfigParser as configparser + import collections as collections_abc # noqa + + class ABC(object): + __metaclass__ = abc.ABCMeta + + +# late imports +fixtures = None +engines = None +exclusions = None +warnings = None +profiling = None +provision = None +assertions = None +requirements = None +config = None +testing = None +util = None +file_config = None + +logging = None +include_tags = set() +exclude_tags = set() +options = None + + +def setup_options(make_option): + make_option( + "--log-info", + action="callback", + type=str, + callback=_log, + help="turn on info logging for <LOG> (multiple OK)", + ) + make_option( + "--log-debug", + action="callback", + type=str, + callback=_log, + help="turn on debug logging for <LOG> (multiple OK)", + ) + make_option( + "--db", + action="append", + type=str, + dest="db", + help="Use prefab database uri. Multiple OK, " + "first one is run by default.", + ) + make_option( + "--dbs", + action="callback", + zeroarg_callback=_list_dbs, + help="List available prefab dbs", + ) + make_option( + "--dburi", + action="append", + type=str, + dest="dburi", + help="Database uri. Multiple OK, " "first one is run by default.", + ) + make_option( + "--dbdriver", + action="append", + type=str, + dest="dbdriver", + help="Additional database drivers to include in tests. " + "These are linked to the existing database URLs by the " + "provisioning system.", + ) + make_option( + "--dropfirst", + action="store_true", + dest="dropfirst", + help="Drop all tables in the target database first", + ) + make_option( + "--disable-asyncio", + action="store_true", + help="disable test / fixtures / provisoning running in asyncio", + ) + make_option( + "--backend-only", + action="store_true", + dest="backend_only", + help="Run only tests marked with __backend__ or __sparse_backend__", + ) + make_option( + "--nomemory", + action="store_true", + dest="nomemory", + help="Don't run memory profiling tests", + ) + make_option( + "--notimingintensive", + action="store_true", + dest="notimingintensive", + help="Don't run timing intensive tests", + ) + make_option( + "--profile-sort", + type=str, + default="cumulative", + dest="profilesort", + help="Type of sort for profiling standard output", + ) + make_option( + "--profile-dump", + type=str, + dest="profiledump", + help="Filename where a single profile run will be dumped", + ) + make_option( + "--postgresql-templatedb", + type=str, + help="name of template database to use for PostgreSQL " + "CREATE DATABASE (defaults to current database)", + ) + make_option( + "--low-connections", + action="store_true", + dest="low_connections", + help="Use a low number of distinct connections - " + "i.e. for Oracle TNS", + ) + make_option( + "--write-idents", + type=str, + dest="write_idents", + help="write out generated follower idents to <file>, " + "when -n<num> is used", + ) + make_option( + "--reversetop", + action="store_true", + dest="reversetop", + default=False, + help="Use a random-ordering set implementation in the ORM " + "(helps reveal dependency issues)", + ) + make_option( + "--requirements", + action="callback", + type=str, + callback=_requirements_opt, + help="requirements class for testing, overrides setup.cfg", + ) + make_option( + "--with-cdecimal", + action="store_true", + dest="cdecimal", + default=False, + help="Monkeypatch the cdecimal library into Python 'decimal' " + "for all tests", + ) + make_option( + "--include-tag", + action="callback", + callback=_include_tag, + type=str, + help="Include tests with tag <tag>", + ) + make_option( + "--exclude-tag", + action="callback", + callback=_exclude_tag, + type=str, + help="Exclude tests with tag <tag>", + ) + make_option( + "--write-profiles", + action="store_true", + dest="write_profiles", + default=False, + help="Write/update failing profiling data.", + ) + make_option( + "--force-write-profiles", + action="store_true", + dest="force_write_profiles", + default=False, + help="Unconditionally write/update profiling data.", + ) + make_option( + "--dump-pyannotate", + type=str, + dest="dump_pyannotate", + help="Run pyannotate and dump json info to given file", + ) + make_option( + "--mypy-extra-test-path", + type=str, + action="append", + default=[], + dest="mypy_extra_test_paths", + help="Additional test directories to add to the mypy tests. " + "This is used only when running mypy tests. Multiple OK", + ) + + +def configure_follower(follower_ident): + """Configure required state for a follower. + + This invokes in the parent process and typically includes + database creation. + + """ + from sqlalchemy.testing import provision + + provision.FOLLOWER_IDENT = follower_ident + + +def memoize_important_follower_config(dict_): + """Store important configuration we will need to send to a follower. + + This invokes in the parent process after normal config is set up. + + This is necessary as pytest seems to not be using forking, so we + start with nothing in memory, *but* it isn't running our argparse + callables, so we have to just copy all of that over. + + """ + dict_["memoized_config"] = { + "include_tags": include_tags, + "exclude_tags": exclude_tags, + } + + +def restore_important_follower_config(dict_): + """Restore important configuration needed by a follower. + + This invokes in the follower process. + + """ + global include_tags, exclude_tags + include_tags.update(dict_["memoized_config"]["include_tags"]) + exclude_tags.update(dict_["memoized_config"]["exclude_tags"]) + + +def read_config(): + global file_config + file_config = configparser.ConfigParser() + file_config.read(["setup.cfg", "test.cfg"]) + + +def pre_begin(opt): + """things to set up early, before coverage might be setup.""" + global options + options = opt + for fn in pre_configure: + fn(options, file_config) + + +def set_coverage_flag(value): + options.has_coverage = value + + +def post_begin(): + """things to set up later, once we know coverage is running.""" + # Lazy setup of other options (post coverage) + for fn in post_configure: + fn(options, file_config) + + # late imports, has to happen after config. + global util, fixtures, engines, exclusions, assertions, provision + global warnings, profiling, config, testing + from sqlalchemy import testing # noqa + from sqlalchemy.testing import fixtures, engines, exclusions # noqa + from sqlalchemy.testing import assertions, warnings, profiling # noqa + from sqlalchemy.testing import config, provision # noqa + from sqlalchemy import util # noqa + + warnings.setup_filters() + + +def _log(opt_str, value, parser): + global logging + if not logging: + import logging + + logging.basicConfig() + + if opt_str.endswith("-info"): + logging.getLogger(value).setLevel(logging.INFO) + elif opt_str.endswith("-debug"): + logging.getLogger(value).setLevel(logging.DEBUG) + + +def _list_dbs(*args): + print("Available --db options (use --dburi to override)") + for macro in sorted(file_config.options("db")): + print("%20s\t%s" % (macro, file_config.get("db", macro))) + sys.exit(0) + + +def _requirements_opt(opt_str, value, parser): + _setup_requirements(value) + + +def _exclude_tag(opt_str, value, parser): + exclude_tags.add(value.replace("-", "_")) + + +def _include_tag(opt_str, value, parser): + include_tags.add(value.replace("-", "_")) + + +pre_configure = [] +post_configure = [] + + +def pre(fn): + pre_configure.append(fn) + return fn + + +def post(fn): + post_configure.append(fn) + return fn + + +@pre +def _setup_options(opt, file_config): + global options + options = opt + + +@pre +def _set_nomemory(opt, file_config): + if opt.nomemory: + exclude_tags.add("memory_intensive") + + +@pre +def _set_notimingintensive(opt, file_config): + if opt.notimingintensive: + exclude_tags.add("timing_intensive") + + +@pre +def _monkeypatch_cdecimal(options, file_config): + if options.cdecimal: + import cdecimal + + sys.modules["decimal"] = cdecimal + + +@post +def _init_symbols(options, file_config): + from sqlalchemy.testing import config + + config._fixture_functions = _fixture_fn_class() + + +@post +def _set_disable_asyncio(opt, file_config): + if opt.disable_asyncio or not py3k: + from sqlalchemy.testing import asyncio + + asyncio.ENABLE_ASYNCIO = False + + +@post +def _engine_uri(options, file_config): + + from sqlalchemy import testing + from sqlalchemy.testing import config + from sqlalchemy.testing import provision + + if options.dburi: + db_urls = list(options.dburi) + else: + db_urls = [] + + extra_drivers = options.dbdriver or [] + + if options.db: + for db_token in options.db: + for db in re.split(r"[,\s]+", db_token): + if db not in file_config.options("db"): + raise RuntimeError( + "Unknown URI specifier '%s'. " + "Specify --dbs for known uris." % db + ) + else: + db_urls.append(file_config.get("db", db)) + + if not db_urls: + db_urls.append(file_config.get("db", "default")) + + config._current = None + + expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers)) + + for db_url in expanded_urls: + log.info("Adding database URL: %s", db_url) + + if options.write_idents and provision.FOLLOWER_IDENT: + with open(options.write_idents, "a") as file_: + file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n") + + cfg = provision.setup_config( + db_url, options, file_config, provision.FOLLOWER_IDENT + ) + if not config._current: + cfg.set_as_current(cfg, testing) + + +@post +def _requirements(options, file_config): + + requirement_cls = file_config.get("sqla_testing", "requirement_cls") + _setup_requirements(requirement_cls) + + +def _setup_requirements(argument): + from sqlalchemy.testing import config + from sqlalchemy import testing + + if config.requirements is not None: + return + + modname, clsname = argument.split(":") + + # importlib.import_module() only introduced in 2.7, a little + # late + mod = __import__(modname) + for component in modname.split(".")[1:]: + mod = getattr(mod, component) + req_cls = getattr(mod, clsname) + + config.requirements = testing.requires = req_cls() + + config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy + + +@post +def _prep_testing_database(options, file_config): + from sqlalchemy.testing import config + + if options.dropfirst: + from sqlalchemy.testing import provision + + for cfg in config.Config.all_configs(): + provision.drop_all_schema_objects(cfg, cfg.db) + + +@post +def _reverse_topological(options, file_config): + if options.reversetop: + from sqlalchemy.orm.util import randomize_unitofwork + + randomize_unitofwork() + + +@post +def _post_setup_options(opt, file_config): + from sqlalchemy.testing import config + + config.options = options + config.file_config = file_config + + +@post +def _setup_profiling(options, file_config): + from sqlalchemy.testing import profiling + + profiling._profile_stats = profiling.ProfileStatsFile( + file_config.get("sqla_testing", "profile_file"), + sort=options.profilesort, + dump=options.profiledump, + ) + + +def want_class(name, cls): + if not issubclass(cls, fixtures.TestBase): + return False + elif name.startswith("_"): + return False + elif ( + config.options.backend_only + and not getattr(cls, "__backend__", False) + and not getattr(cls, "__sparse_backend__", False) + and not getattr(cls, "__only_on__", False) + ): + return False + else: + return True + + +def want_method(cls, fn): + if not fn.__name__.startswith("test_"): + return False + elif fn.__module__ is None: + return False + elif include_tags: + return ( + hasattr(cls, "__tags__") + and exclusions.tags(cls.__tags__).include_test( + include_tags, exclude_tags + ) + ) or ( + hasattr(fn, "_sa_exclusion_extend") + and fn._sa_exclusion_extend.include_test( + include_tags, exclude_tags + ) + ) + elif exclude_tags and hasattr(cls, "__tags__"): + return exclusions.tags(cls.__tags__).include_test( + include_tags, exclude_tags + ) + elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"): + return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags) + else: + return True + + +def generate_sub_tests(cls, module): + if getattr(cls, "__backend__", False) or getattr( + cls, "__sparse_backend__", False + ): + sparse = getattr(cls, "__sparse_backend__", False) + for cfg in _possible_configs_for_cls(cls, sparse=sparse): + orig_name = cls.__name__ + + # we can have special chars in these names except for the + # pytest junit plugin, which is tripped up by the brackets + # and periods, so sanitize + + alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name) + alpha_name = re.sub(r"_+$", "", alpha_name) + name = "%s_%s" % (cls.__name__, alpha_name) + subcls = type( + name, + (cls,), + {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg}, + ) + setattr(module, name, subcls) + yield subcls + else: + yield cls + + +def start_test_class_outside_fixtures(cls): + _do_skips(cls) + _setup_engine(cls) + + +def stop_test_class(cls): + # close sessions, immediate connections, etc. + fixtures.stop_test_class_inside_fixtures(cls) + + # close outstanding connection pool connections, dispose of + # additional engines + engines.testing_reaper.stop_test_class_inside_fixtures() + + +def stop_test_class_outside_fixtures(cls): + engines.testing_reaper.stop_test_class_outside_fixtures() + provision.stop_test_class_outside_fixtures(config, config.db, cls) + try: + if not options.low_connections: + assertions.global_cleanup_assertions() + finally: + _restore_engine() + + +def _restore_engine(): + if config._current: + config._current.reset(testing) + + +def final_process_cleanup(): + engines.testing_reaper.final_cleanup() + assertions.global_cleanup_assertions() + _restore_engine() + + +def _setup_engine(cls): + if getattr(cls, "__engine_options__", None): + opts = dict(cls.__engine_options__) + opts["scope"] = "class" + eng = engines.testing_engine(options=opts) + config._current.push_engine(eng, testing) + + +def before_test(test, test_module_name, test_class, test_name): + + # format looks like: + # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause" + + name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__) + + id_ = "%s.%s.%s" % (test_module_name, name, test_name) + + profiling._start_current_test(id_) + + +def after_test(test): + fixtures.after_test() + engines.testing_reaper.after_test() + + +def after_test_fixtures(test): + engines.testing_reaper.after_test_outside_fixtures(test) + + +def _possible_configs_for_cls(cls, reasons=None, sparse=False): + all_configs = set(config.Config.all_configs()) + + if cls.__unsupported_on__: + spec = exclusions.db_spec(*cls.__unsupported_on__) + for config_obj in list(all_configs): + if spec(config_obj): + all_configs.remove(config_obj) + + if getattr(cls, "__only_on__", None): + spec = exclusions.db_spec(*util.to_list(cls.__only_on__)) + for config_obj in list(all_configs): + if not spec(config_obj): + all_configs.remove(config_obj) + + if getattr(cls, "__only_on_config__", None): + all_configs.intersection_update([cls.__only_on_config__]) + + if hasattr(cls, "__requires__"): + requirements = config.requirements + for config_obj in list(all_configs): + for requirement in cls.__requires__: + check = getattr(requirements, requirement) + + skip_reasons = check.matching_config_reasons(config_obj) + if skip_reasons: + all_configs.remove(config_obj) + if reasons is not None: + reasons.extend(skip_reasons) + break + + if hasattr(cls, "__prefer_requires__"): + non_preferred = set() + requirements = config.requirements + for config_obj in list(all_configs): + for requirement in cls.__prefer_requires__: + check = getattr(requirements, requirement) + + if not check.enabled_for_config(config_obj): + non_preferred.add(config_obj) + if all_configs.difference(non_preferred): + all_configs.difference_update(non_preferred) + + if sparse: + # pick only one config from each base dialect + # sorted so we get the same backend each time selecting the highest + # server version info. + per_dialect = {} + for cfg in reversed( + sorted( + all_configs, + key=lambda cfg: ( + cfg.db.name, + cfg.db.driver, + cfg.db.dialect.server_version_info, + ), + ) + ): + db = cfg.db.name + if db not in per_dialect: + per_dialect[db] = cfg + return per_dialect.values() + + return all_configs + + +def _do_skips(cls): + reasons = [] + all_configs = _possible_configs_for_cls(cls, reasons) + + if getattr(cls, "__skip_if__", False): + for c in getattr(cls, "__skip_if__"): + if c(): + config.skip_test( + "'%s' skipped by %s" % (cls.__name__, c.__name__) + ) + + if not all_configs: + msg = "'%s' unsupported on any DB implementation %s%s" % ( + cls.__name__, + ", ".join( + "'%s(%s)+%s'" + % ( + config_obj.db.name, + ".".join( + str(dig) + for dig in exclusions._server_version(config_obj.db) + ), + config_obj.db.driver, + ) + for config_obj in config.Config.all_configs() + ), + ", ".join(reasons), + ) + config.skip_test(msg) + elif hasattr(cls, "__prefer_backends__"): + non_preferred = set() + spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__)) + for config_obj in all_configs: + if not spec(config_obj): + non_preferred.add(config_obj) + if all_configs.difference(non_preferred): + all_configs.difference_update(non_preferred) + + if config._current not in all_configs: + _setup_config(all_configs.pop(), cls) + + +def _setup_config(config_obj, ctx): + config._current.push(config_obj, testing) + + +class FixtureFunctions(ABC): + @abc.abstractmethod + def skip_test_exception(self, *arg, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def combinations(self, *args, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def param_ident(self, *args, **kw): + raise NotImplementedError() + + @abc.abstractmethod + def fixture(self, *arg, **kw): + raise NotImplementedError() + + def get_current_test_name(self): + raise NotImplementedError() + + @abc.abstractmethod + def mark_base_test_class(self): + raise NotImplementedError() + + +_fixture_fn_class = None + + +def set_fixture_functions(fixture_fn_class): + global _fixture_fn_class + _fixture_fn_class = fixture_fn_class diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py new file mode 100644 index 0000000..5a51582 --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -0,0 +1,820 @@ +try: + # installed by bootstrap.py + import sqla_plugin_base as plugin_base +except ImportError: + # assume we're a package, use traditional import + from . import plugin_base + +import argparse +import collections +from functools import update_wrapper +import inspect +import itertools +import operator +import os +import re +import sys +import uuid + +import pytest + + +py2k = sys.version_info < (3, 0) +if py2k: + try: + import sqla_reinvent_fixtures as reinvent_fixtures_py2k + except ImportError: + from . import reinvent_fixtures_py2k + + +def pytest_addoption(parser): + group = parser.getgroup("sqlalchemy") + + def make_option(name, **kw): + callback_ = kw.pop("callback", None) + if callback_: + + class CallableAction(argparse.Action): + def __call__( + self, parser, namespace, values, option_string=None + ): + callback_(option_string, values, parser) + + kw["action"] = CallableAction + + zeroarg_callback = kw.pop("zeroarg_callback", None) + if zeroarg_callback: + + class CallableAction(argparse.Action): + def __init__( + self, + option_strings, + dest, + default=False, + required=False, + help=None, # noqa + ): + super(CallableAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=True, + default=default, + required=required, + help=help, + ) + + def __call__( + self, parser, namespace, values, option_string=None + ): + zeroarg_callback(option_string, values, parser) + + kw["action"] = CallableAction + + group.addoption(name, **kw) + + plugin_base.setup_options(make_option) + plugin_base.read_config() + + +def pytest_configure(config): + if config.pluginmanager.hasplugin("xdist"): + config.pluginmanager.register(XDistHooks()) + + if hasattr(config, "workerinput"): + plugin_base.restore_important_follower_config(config.workerinput) + plugin_base.configure_follower(config.workerinput["follower_ident"]) + else: + if config.option.write_idents and os.path.exists( + config.option.write_idents + ): + os.remove(config.option.write_idents) + + plugin_base.pre_begin(config.option) + + plugin_base.set_coverage_flag( + bool(getattr(config.option, "cov_source", False)) + ) + + plugin_base.set_fixture_functions(PytestFixtureFunctions) + + if config.option.dump_pyannotate: + global DUMP_PYANNOTATE + DUMP_PYANNOTATE = True + + +DUMP_PYANNOTATE = False + + +@pytest.fixture(autouse=True) +def collect_types_fixture(): + if DUMP_PYANNOTATE: + from pyannotate_runtime import collect_types + + collect_types.start() + yield + if DUMP_PYANNOTATE: + collect_types.stop() + + +def pytest_sessionstart(session): + from sqlalchemy.testing import asyncio + + asyncio._assume_async(plugin_base.post_begin) + + +def pytest_sessionfinish(session): + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup) + + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + collect_types.dump_stats(session.config.option.dump_pyannotate) + + +def pytest_collection_finish(session): + if session.config.option.dump_pyannotate: + from pyannotate_runtime import collect_types + + lib_sqlalchemy = os.path.abspath("lib/sqlalchemy") + + def _filter(filename): + filename = os.path.normpath(os.path.abspath(filename)) + if "lib/sqlalchemy" not in os.path.commonpath( + [filename, lib_sqlalchemy] + ): + return None + if "testing" in filename: + return None + + return filename + + collect_types.init_types_collection(filter_filename=_filter) + + +class XDistHooks(object): + def pytest_configure_node(self, node): + from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio + + # the master for each node fills workerinput dictionary + # which pytest-xdist will transfer to the subprocess + + plugin_base.memoize_important_follower_config(node.workerinput) + + node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12] + + asyncio._maybe_async_provisioning( + provision.create_follower_db, node.workerinput["follower_ident"] + ) + + def pytest_testnodedown(self, node, error): + from sqlalchemy.testing import provision + from sqlalchemy.testing import asyncio + + asyncio._maybe_async_provisioning( + provision.drop_follower_db, node.workerinput["follower_ident"] + ) + + +def pytest_collection_modifyitems(session, config, items): + + # look for all those classes that specify __backend__ and + # expand them out into per-database test cases. + + # this is much easier to do within pytest_pycollect_makeitem, however + # pytest is iterating through cls.__dict__ as makeitem is + # called which causes a "dictionary changed size" error on py3k. + # I'd submit a pullreq for them to turn it into a list first, but + # it's to suit the rather odd use case here which is that we are adding + # new classes to a module on the fly. + + from sqlalchemy.testing import asyncio + + rebuilt_items = collections.defaultdict( + lambda: collections.defaultdict(list) + ) + + items[:] = [ + item + for item in items + if item.getparent(pytest.Class) is not None + and not item.getparent(pytest.Class).name.startswith("_") + ] + + test_classes = set(item.getparent(pytest.Class) for item in items) + + def collect(element): + for inst_or_fn in element.collect(): + if isinstance(inst_or_fn, pytest.Collector): + # no yield from in 2.7 + for el in collect(inst_or_fn): + yield el + else: + yield inst_or_fn + + def setup_test_classes(): + for test_class in test_classes: + for sub_cls in plugin_base.generate_sub_tests( + test_class.cls, test_class.module + ): + if sub_cls is not test_class.cls: + per_cls_dict = rebuilt_items[test_class.cls] + + # support pytest 5.4.0 and above pytest.Class.from_parent + ctor = getattr(pytest.Class, "from_parent", pytest.Class) + module = test_class.getparent(pytest.Module) + for fn in collect( + ctor(name=sub_cls.__name__, parent=module) + ): + per_cls_dict[fn.name].append(fn) + + # class requirements will sometimes need to access the DB to check + # capabilities, so need to do this for async + asyncio._maybe_async_provisioning(setup_test_classes) + + newitems = [] + for item in items: + cls_ = item.cls + if cls_ in rebuilt_items: + newitems.extend(rebuilt_items[cls_][item.name]) + else: + newitems.append(item) + + if py2k: + for item in newitems: + reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item) + + # seems like the functions attached to a test class aren't sorted already? + # is that true and why's that? (when using unittest, they're sorted) + items[:] = sorted( + newitems, + key=lambda item: ( + item.getparent(pytest.Module).name, + item.getparent(pytest.Class).name, + item.name, + ), + ) + + +def pytest_pycollect_makeitem(collector, name, obj): + if inspect.isclass(obj) and plugin_base.want_class(name, obj): + from sqlalchemy.testing import config + + if config.any_async: + obj = _apply_maybe_async(obj) + + ctor = getattr(pytest.Class, "from_parent", pytest.Class) + return [ + ctor(name=parametrize_cls.__name__, parent=collector) + for parametrize_cls in _parametrize_cls(collector.module, obj) + ] + elif ( + inspect.isfunction(obj) + and collector.cls is not None + and plugin_base.want_method(collector.cls, obj) + ): + # None means, fall back to default logic, which includes + # method-level parametrize + return None + else: + # empty list means skip this item + return [] + + +def _is_wrapped_coroutine_function(fn): + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + + return inspect.iscoroutinefunction(fn) + + +def _apply_maybe_async(obj, recurse=True): + from sqlalchemy.testing import asyncio + + for name, value in vars(obj).items(): + if ( + (callable(value) or isinstance(value, classmethod)) + and not getattr(value, "_maybe_async_applied", False) + and (name.startswith("test_")) + and not _is_wrapped_coroutine_function(value) + ): + is_classmethod = False + if isinstance(value, classmethod): + value = value.__func__ + is_classmethod = True + + @_pytest_fn_decorator + def make_async(fn, *args, **kwargs): + return asyncio._maybe_async(fn, *args, **kwargs) + + do_async = make_async(value) + if is_classmethod: + do_async = classmethod(do_async) + do_async._maybe_async_applied = True + + setattr(obj, name, do_async) + if recurse: + for cls in obj.mro()[1:]: + if cls != object: + _apply_maybe_async(cls, False) + return obj + + +def _parametrize_cls(module, cls): + """implement a class-based version of pytest parametrize.""" + + if "_sa_parametrize" not in cls.__dict__: + return [cls] + + _sa_parametrize = cls._sa_parametrize + classes = [] + for full_param_set in itertools.product( + *[params for argname, params in _sa_parametrize] + ): + cls_variables = {} + + for argname, param in zip( + [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set + ): + if not argname: + raise TypeError("need argnames for class-based combinations") + argname_split = re.split(r",\s*", argname) + for arg, val in zip(argname_split, param.values): + cls_variables[arg] = val + parametrized_name = "_".join( + # token is a string, but in py2k pytest is giving us a unicode, + # so call str() on it. + str(re.sub(r"\W", "", token)) + for param in full_param_set + for token in param.id.split("-") + ) + name = "%s_%s" % (cls.__name__, parametrized_name) + newcls = type.__new__(type, name, (cls,), cls_variables) + setattr(module, name, newcls) + classes.append(newcls) + return classes + + +_current_class = None + + +def pytest_runtest_setup(item): + from sqlalchemy.testing import asyncio + + # pytest_runtest_setup runs *before* pytest fixtures with scope="class". + # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest + # for the whole class and has to run things that are across all current + # databases, so we run this outside of the pytest fixture system altogether + # and ensure asyncio greenlet if any engines are async + + global _current_class + + if isinstance(item, pytest.Function) and _current_class is None: + asyncio._maybe_async_provisioning( + plugin_base.start_test_class_outside_fixtures, + item.cls, + ) + _current_class = item.getparent(pytest.Class) + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_teardown(item, nextitem): + # runs inside of pytest function fixture scope + # after test function runs + from sqlalchemy.testing import asyncio + from sqlalchemy.util import string_types + + asyncio._maybe_async(plugin_base.after_test, item) + + yield + # this is now after all the fixture teardown have run, the class can be + # finalized. Since pytest v7 this finalizer can no longer be added in + # pytest_runtest_setup since the class has not yet been setup at that + # time. + # See https://github.com/pytest-dev/pytest/issues/9343 + global _current_class, _current_report + + if _current_class is not None and ( + # last test or a new class + nextitem is None + or nextitem.getparent(pytest.Class) is not _current_class + ): + _current_class = None + + try: + asyncio._maybe_async_provisioning( + plugin_base.stop_test_class_outside_fixtures, item.cls + ) + except Exception as e: + # in case of an exception during teardown attach the original + # error to the exception message, otherwise it will get lost + if _current_report.failed: + if not e.args: + e.args = ( + "__Original test failure__:\n" + + _current_report.longreprtext, + ) + elif e.args[-1] and isinstance(e.args[-1], string_types): + args = list(e.args) + args[-1] += ( + "\n__Original test failure__:\n" + + _current_report.longreprtext + ) + e.args = tuple(args) + else: + e.args += ( + "__Original test failure__", + _current_report.longreprtext, + ) + raise + finally: + _current_report = None + + +def pytest_runtest_call(item): + # runs inside of pytest function fixture scope + # before test function runs + + from sqlalchemy.testing import asyncio + + asyncio._maybe_async( + plugin_base.before_test, + item, + item.module.__name__, + item.cls, + item.name, + ) + + +_current_report = None + + +def pytest_runtest_logreport(report): + global _current_report + if report.when == "call": + _current_report = report + + +@pytest.fixture(scope="class") +def setup_class_methods(request): + from sqlalchemy.testing import asyncio + + cls = request.cls + + if hasattr(cls, "setup_test_class"): + asyncio._maybe_async(cls.setup_test_class) + + if py2k: + reinvent_fixtures_py2k.run_class_fixture_setup(request) + + yield + + if py2k: + reinvent_fixtures_py2k.run_class_fixture_teardown(request) + + if hasattr(cls, "teardown_test_class"): + asyncio._maybe_async(cls.teardown_test_class) + + asyncio._maybe_async(plugin_base.stop_test_class, cls) + + +@pytest.fixture(scope="function") +def setup_test_methods(request): + from sqlalchemy.testing import asyncio + + # called for each test + + self = request.instance + + # before this fixture runs: + + # 1. function level "autouse" fixtures under py3k (examples: TablesTest + # define tables / data, MappedTest define tables / mappers / data) + + # 2. run homegrown function level "autouse" fixtures under py2k + if py2k: + reinvent_fixtures_py2k.run_fn_fixture_setup(request) + + # 3. run outer xdist-style setup + if hasattr(self, "setup_test"): + asyncio._maybe_async(self.setup_test) + + # alembic test suite is using setUp and tearDown + # xdist methods; support these in the test suite + # for the near term + if hasattr(self, "setUp"): + asyncio._maybe_async(self.setUp) + + # inside the yield: + # 4. function level fixtures defined on test functions themselves, + # e.g. "connection", "metadata" run next + + # 5. pytest hook pytest_runtest_call then runs + + # 6. test itself runs + + yield + + # yield finishes: + + # 7. function level fixtures defined on test functions + # themselves, e.g. "connection" rolls back the transaction, "metadata" + # emits drop all + + # 8. pytest hook pytest_runtest_teardown hook runs, this is associated + # with fixtures close all sessions, provisioning.stop_test_class(), + # engines.testing_reaper -> ensure all connection pool connections + # are returned, engines created by testing_engine that aren't the + # config engine are disposed + + asyncio._maybe_async(plugin_base.after_test_fixtures, self) + + # 10. run xdist-style teardown + if hasattr(self, "tearDown"): + asyncio._maybe_async(self.tearDown) + + if hasattr(self, "teardown_test"): + asyncio._maybe_async(self.teardown_test) + + # 11. run homegrown function-level "autouse" fixtures under py2k + if py2k: + reinvent_fixtures_py2k.run_fn_fixture_teardown(request) + + # 12. function level "autouse" fixtures under py3k (examples: TablesTest / + # MappedTest delete table data, possibly drop tables and clear mappers + # depending on the flags defined by the test class) + + +def getargspec(fn): + if sys.version_info.major == 3: + return inspect.getfullargspec(fn) + else: + return inspect.getargspec(fn) + + +def _pytest_fn_decorator(target): + """Port of langhelpers.decorator with pytest-specific tricks.""" + + from sqlalchemy.util.langhelpers import format_argspec_plus + from sqlalchemy.util.compat import inspect_getfullargspec + + def _exec_code_in_env(code, env, fn_name): + exec(code, env) + return env[fn_name] + + def decorate(fn, add_positional_parameters=()): + + spec = inspect_getfullargspec(fn) + if add_positional_parameters: + spec.args.extend(add_positional_parameters) + + metadata = dict( + __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__ + ) + metadata.update(format_argspec_plus(spec, grouped=False)) + code = ( + """\ +def %(name)s(%(args)s): + return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s) +""" + % metadata + ) + decorated = _exec_code_in_env( + code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__ + ) + if not add_positional_parameters: + decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__ + decorated.__wrapped__ = fn + return update_wrapper(decorated, fn) + else: + # this is the pytest hacky part. don't do a full update wrapper + # because pytest is really being sneaky about finding the args + # for the wrapped function + decorated.__module__ = fn.__module__ + decorated.__name__ = fn.__name__ + if hasattr(fn, "pytestmark"): + decorated.pytestmark = fn.pytestmark + return decorated + + return decorate + + +class PytestFixtureFunctions(plugin_base.FixtureFunctions): + def skip_test_exception(self, *arg, **kw): + return pytest.skip.Exception(*arg, **kw) + + def mark_base_test_class(self): + return pytest.mark.usefixtures( + "setup_class_methods", "setup_test_methods" + ) + + _combination_id_fns = { + "i": lambda obj: obj, + "r": repr, + "s": str, + "n": lambda obj: obj.__name__ + if hasattr(obj, "__name__") + else type(obj).__name__, + } + + def combinations(self, *arg_sets, **kw): + """Facade for pytest.mark.parametrize. + + Automatically derives argument names from the callable which in our + case is always a method on a class with positional arguments. + + ids for parameter sets are derived using an optional template. + + """ + from sqlalchemy.testing import exclusions + + if sys.version_info.major == 3: + if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"): + arg_sets = list(arg_sets[0]) + else: + if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"): + arg_sets = list(arg_sets[0]) + + argnames = kw.pop("argnames", None) + + def _filter_exclusions(args): + result = [] + gathered_exclusions = [] + for a in args: + if isinstance(a, exclusions.compound): + gathered_exclusions.append(a) + else: + result.append(a) + + return result, gathered_exclusions + + id_ = kw.pop("id_", None) + + tobuild_pytest_params = [] + has_exclusions = False + if id_: + _combination_id_fns = self._combination_id_fns + + # because itemgetter is not consistent for one argument vs. + # multiple, make it multiple in all cases and use a slice + # to omit the first argument + _arg_getter = operator.itemgetter( + 0, + *[ + idx + for idx, char in enumerate(id_) + if char in ("n", "r", "s", "a") + ] + ) + fns = [ + (operator.itemgetter(idx), _combination_id_fns[char]) + for idx, char in enumerate(id_) + if char in _combination_id_fns + ] + + for arg in arg_sets: + if not isinstance(arg, tuple): + arg = (arg,) + + fn_params, param_exclusions = _filter_exclusions(arg) + + parameters = _arg_getter(fn_params)[1:] + + if param_exclusions: + has_exclusions = True + + tobuild_pytest_params.append( + ( + parameters, + param_exclusions, + "-".join( + comb_fn(getter(arg)) for getter, comb_fn in fns + ), + ) + ) + + else: + + for arg in arg_sets: + if not isinstance(arg, tuple): + arg = (arg,) + + fn_params, param_exclusions = _filter_exclusions(arg) + + if param_exclusions: + has_exclusions = True + + tobuild_pytest_params.append( + (fn_params, param_exclusions, None) + ) + + pytest_params = [] + for parameters, param_exclusions, id_ in tobuild_pytest_params: + if has_exclusions: + parameters += (param_exclusions,) + + param = pytest.param(*parameters, id=id_) + pytest_params.append(param) + + def decorate(fn): + if inspect.isclass(fn): + if has_exclusions: + raise NotImplementedError( + "exclusions not supported for class level combinations" + ) + if "_sa_parametrize" not in fn.__dict__: + fn._sa_parametrize = [] + fn._sa_parametrize.append((argnames, pytest_params)) + return fn + else: + if argnames is None: + _argnames = getargspec(fn).args[1:] + else: + _argnames = re.split(r", *", argnames) + + if has_exclusions: + _argnames += ["_exclusions"] + + @_pytest_fn_decorator + def check_exclusions(fn, *args, **kw): + _exclusions = args[-1] + if _exclusions: + exlu = exclusions.compound().add(*_exclusions) + fn = exlu(fn) + return fn(*args[0:-1], **kw) + + def process_metadata(spec): + spec.args.append("_exclusions") + + fn = check_exclusions( + fn, add_positional_parameters=("_exclusions",) + ) + + return pytest.mark.parametrize(_argnames, pytest_params)(fn) + + return decorate + + def param_ident(self, *parameters): + ident = parameters[0] + return pytest.param(*parameters[1:], id=ident) + + def fixture(self, *arg, **kw): + from sqlalchemy.testing import config + from sqlalchemy.testing import asyncio + + # wrapping pytest.fixture function. determine if + # decorator was called as @fixture or @fixture(). + if len(arg) > 0 and callable(arg[0]): + # was called as @fixture(), we have the function to wrap. + fn = arg[0] + arg = arg[1:] + else: + # was called as @fixture, don't have the function yet. + fn = None + + # create a pytest.fixture marker. because the fn is not being + # passed, this is always a pytest.FixtureFunctionMarker() + # object (or whatever pytest is calling it when you read this) + # that is waiting for a function. + fixture = pytest.fixture(*arg, **kw) + + # now apply wrappers to the function, including fixture itself + + def wrap(fn): + if config.any_async: + fn = asyncio._maybe_async_wrapper(fn) + # other wrappers may be added here + + if py2k and "autouse" in kw: + # py2k workaround for too-slow collection of autouse fixtures + # in pytest 4.6.11. See notes in reinvent_fixtures_py2k for + # rationale. + + # comment this condition out in order to disable the + # py2k workaround entirely. + reinvent_fixtures_py2k.add_fixture(fn, fixture) + else: + # now apply FixtureFunctionMarker + fn = fixture(fn) + + return fn + + if fn: + return wrap(fn) + else: + return wrap + + def get_current_test_name(self): + return os.environ.get("PYTEST_CURRENT_TEST") + + def async_test(self, fn): + from sqlalchemy.testing import asyncio + + @_pytest_fn_decorator + def decorate(fn, *args, **kwargs): + asyncio._run_coroutine_function(fn, *args, **kwargs) + + return decorate(fn) diff --git a/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py new file mode 100644 index 0000000..36b6841 --- /dev/null +++ b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py @@ -0,0 +1,112 @@ +""" +invent a quick version of pytest autouse fixtures as pytest's unacceptably slow +collection/high memory use in pytest 4.6.11, which is the highest version that +works in py2k. + +by "too-slow" we mean the test suite can't even manage to be collected for a +single process in less than 70 seconds or so and memory use seems to be very +high as well. for two or four workers the job just times out after ten +minutes. + +so instead we have invented a very limited form of these fixtures, as our +current use of "autouse" fixtures are limited to those in fixtures.py. + +assumptions for these fixtures: + +1. we are only using "function" or "class" scope + +2. the functions must be associated with a test class + +3. the fixture functions cannot themselves use pytest fixtures + +4. the fixture functions must use yield, not return + +When py2k support is removed and we can stay on a modern pytest version, this +can all be removed. + + +""" +import collections + + +_py2k_fixture_fn_names = collections.defaultdict(set) +_py2k_class_fixtures = collections.defaultdict( + lambda: collections.defaultdict(set) +) +_py2k_function_fixtures = collections.defaultdict( + lambda: collections.defaultdict(set) +) + +_py2k_cls_fixture_stack = [] +_py2k_fn_fixture_stack = [] + + +def add_fixture(fn, fixture): + assert fixture.scope in ("class", "function") + _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope)) + + +def scan_for_fixtures_to_use_for_class(item): + test_class = item.parent.parent.obj + + for name in _py2k_fixture_fn_names: + for fixture_fn, scope in _py2k_fixture_fn_names[name]: + meth = getattr(test_class, name, None) + if meth and meth.im_func is fixture_fn: + for sup in test_class.__mro__: + if name in sup.__dict__: + if scope == "class": + _py2k_class_fixtures[test_class][sup].add(meth) + elif scope == "function": + _py2k_function_fixtures[test_class][sup].add(meth) + break + break + + +def run_class_fixture_setup(request): + + cls = request.cls + self = cls.__new__(cls) + + fixtures_for_this_class = _py2k_class_fixtures.get(cls) + + if fixtures_for_this_class: + for sup_ in cls.__mro__: + for fn in fixtures_for_this_class.get(sup_, ()): + iter_ = fn(self) + next(iter_) + + _py2k_cls_fixture_stack.append(iter_) + + +def run_class_fixture_teardown(request): + while _py2k_cls_fixture_stack: + iter_ = _py2k_cls_fixture_stack.pop(-1) + try: + next(iter_) + except StopIteration: + pass + + +def run_fn_fixture_setup(request): + cls = request.cls + self = request.instance + + fixtures_for_this_class = _py2k_function_fixtures.get(cls) + + if fixtures_for_this_class: + for sup_ in reversed(cls.__mro__): + for fn in fixtures_for_this_class.get(sup_, ()): + iter_ = fn(self) + next(iter_) + + _py2k_fn_fixture_stack.append(iter_) + + +def run_fn_fixture_teardown(request): + while _py2k_fn_fixture_stack: + iter_ = _py2k_fn_fixture_stack.pop(-1) + try: + next(iter_) + except StopIteration: + pass diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py new file mode 100644 index 0000000..4132630 --- /dev/null +++ b/lib/sqlalchemy/testing/profiling.py @@ -0,0 +1,335 @@ +# testing/profiling.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Profiling support for unit and performance tests. + +These are special purpose profiling methods which operate +in a more fine-grained way than nose's profiling plugin. + +""" + +import collections +import contextlib +import os +import platform +import pstats +import re +import sys + +from . import config +from .util import gc_collect +from ..util import has_compiled_ext + + +try: + import cProfile +except ImportError: + cProfile = None + +_profile_stats = None +"""global ProfileStatsFileInstance. + +plugin_base assigns this at the start of all tests. + +""" + + +_current_test = None +"""String id of current test. + +plugin_base assigns this at the start of each test using +_start_current_test. + +""" + + +def _start_current_test(id_): + global _current_test + _current_test = id_ + + if _profile_stats.force_write: + _profile_stats.reset_count() + + +class ProfileStatsFile(object): + """Store per-platform/fn profiling results in a file. + + There was no json module available when this was written, but now + the file format which is very deterministically line oriented is kind of + handy in any case for diffs and merges. + + """ + + def __init__(self, filename, sort="cumulative", dump=None): + self.force_write = ( + config.options is not None and config.options.force_write_profiles + ) + self.write = self.force_write or ( + config.options is not None and config.options.write_profiles + ) + self.fname = os.path.abspath(filename) + self.short_fname = os.path.split(self.fname)[-1] + self.data = collections.defaultdict( + lambda: collections.defaultdict(dict) + ) + self.dump = dump + self.sort = sort + self._read() + if self.write: + # rewrite for the case where features changed, + # etc. + self._write() + + @property + def platform_key(self): + + dbapi_key = config.db.name + "_" + config.db.driver + + if config.db.name == "sqlite" and config.db.dialect._is_url_file_db( + config.db.url + ): + dbapi_key += "_file" + + # keep it at 2.7, 3.1, 3.2, etc. for now. + py_version = ".".join([str(v) for v in sys.version_info[0:2]]) + + platform_tokens = [ + platform.machine(), + platform.system().lower(), + platform.python_implementation().lower(), + py_version, + dbapi_key, + ] + + platform_tokens.append( + "nativeunicode" + if config.db.dialect.convert_unicode + else "dbapiunicode" + ) + _has_cext = has_compiled_ext() + platform_tokens.append(_has_cext and "cextensions" or "nocextensions") + return "_".join(platform_tokens) + + def has_stats(self): + test_key = _current_test + return ( + test_key in self.data and self.platform_key in self.data[test_key] + ) + + def result(self, callcount): + test_key = _current_test + per_fn = self.data[test_key] + per_platform = per_fn[self.platform_key] + + if "counts" not in per_platform: + per_platform["counts"] = counts = [] + else: + counts = per_platform["counts"] + + if "current_count" not in per_platform: + per_platform["current_count"] = current_count = 0 + else: + current_count = per_platform["current_count"] + + has_count = len(counts) > current_count + + if not has_count: + counts.append(callcount) + if self.write: + self._write() + result = None + else: + result = per_platform["lineno"], counts[current_count] + per_platform["current_count"] += 1 + return result + + def reset_count(self): + test_key = _current_test + # since self.data is a defaultdict, don't access a key + # if we don't know it's there first. + if test_key not in self.data: + return + per_fn = self.data[test_key] + if self.platform_key not in per_fn: + return + per_platform = per_fn[self.platform_key] + if "counts" in per_platform: + per_platform["counts"][:] = [] + + def replace(self, callcount): + test_key = _current_test + per_fn = self.data[test_key] + per_platform = per_fn[self.platform_key] + counts = per_platform["counts"] + current_count = per_platform["current_count"] + if current_count < len(counts): + counts[current_count - 1] = callcount + else: + counts[-1] = callcount + if self.write: + self._write() + + def _header(self): + return ( + "# %s\n" + "# This file is written out on a per-environment basis.\n" + "# For each test in aaa_profiling, the corresponding " + "function and \n" + "# environment is located within this file. " + "If it doesn't exist,\n" + "# the test is skipped.\n" + "# If a callcount does exist, it is compared " + "to what we received. \n" + "# assertions are raised if the counts do not match.\n" + "# \n" + "# To add a new callcount test, apply the function_call_count \n" + "# decorator and re-run the tests using the --write-profiles \n" + "# option - this file will be rewritten including the new count.\n" + "# \n" + ) % (self.fname) + + def _read(self): + try: + profile_f = open(self.fname) + except IOError: + return + for lineno, line in enumerate(profile_f): + line = line.strip() + if not line or line.startswith("#"): + continue + + test_key, platform_key, counts = line.split() + per_fn = self.data[test_key] + per_platform = per_fn[platform_key] + c = [int(count) for count in counts.split(",")] + per_platform["counts"] = c + per_platform["lineno"] = lineno + 1 + per_platform["current_count"] = 0 + profile_f.close() + + def _write(self): + print(("Writing profile file %s" % self.fname)) + profile_f = open(self.fname, "w") + profile_f.write(self._header()) + for test_key in sorted(self.data): + + per_fn = self.data[test_key] + profile_f.write("\n# TEST: %s\n\n" % test_key) + for platform_key in sorted(per_fn): + per_platform = per_fn[platform_key] + c = ",".join(str(count) for count in per_platform["counts"]) + profile_f.write("%s %s %s\n" % (test_key, platform_key, c)) + profile_f.close() + + +def function_call_count(variance=0.05, times=1, warmup=0): + """Assert a target for a test case's function call count. + + The main purpose of this assertion is to detect changes in + callcounts for various functions - the actual number is not as important. + Callcounts are stored in a file keyed to Python version and OS platform + information. This file is generated automatically for new tests, + and versioned so that unexpected changes in callcounts will be detected. + + """ + + # use signature-rewriting decorator function so that pytest fixtures + # still work on py27. In Py3, update_wrapper() alone is good enough, + # likely due to the introduction of __signature__. + + from sqlalchemy.util import decorator + from sqlalchemy.util import deprecations + from sqlalchemy.engine import row + from sqlalchemy.testing import mock + + @decorator + def wrap(fn, *args, **kw): + + with mock.patch.object( + deprecations, "SQLALCHEMY_WARN_20", False + ), mock.patch.object( + row.LegacyRow, "_default_key_style", row.KEY_OBJECTS_NO_WARN + ): + for warm in range(warmup): + fn(*args, **kw) + + timerange = range(times) + with count_functions(variance=variance): + for time in timerange: + rv = fn(*args, **kw) + return rv + + return wrap + + +@contextlib.contextmanager +def count_functions(variance=0.05): + if cProfile is None: + raise config._skip_test_exception("cProfile is not installed") + + if not _profile_stats.has_stats() and not _profile_stats.write: + config.skip_test( + "No profiling stats available on this " + "platform for this function. Run tests with " + "--write-profiles to add statistics to %s for " + "this platform." % _profile_stats.short_fname + ) + + gc_collect() + + pr = cProfile.Profile() + pr.enable() + # began = time.time() + yield + # ended = time.time() + pr.disable() + + # s = compat.StringIO() + stats = pstats.Stats(pr, stream=sys.stdout) + + # timespent = ended - began + callcount = stats.total_calls + + expected = _profile_stats.result(callcount) + + if expected is None: + expected_count = None + else: + line_no, expected_count = expected + + print(("Pstats calls: %d Expected %s" % (callcount, expected_count))) + stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort)) + stats.print_stats() + if _profile_stats.dump: + base, ext = os.path.splitext(_profile_stats.dump) + test_name = _current_test.split(".")[-1] + dumpfile = "%s_%s%s" % (base, test_name, ext or ".profile") + stats.dump_stats(dumpfile) + print("Dumped stats to file %s" % dumpfile) + # stats.print_callers() + if _profile_stats.force_write: + _profile_stats.replace(callcount) + elif expected_count: + deviance = int(callcount * variance) + failed = abs(callcount - expected_count) > deviance + + if failed: + if _profile_stats.write: + _profile_stats.replace(callcount) + else: + raise AssertionError( + "Adjusted function call count %s not within %s%% " + "of expected %s, platform %s. Rerun with " + "--write-profiles to " + "regenerate this callcount." + % ( + callcount, + (variance * 100), + expected_count, + _profile_stats.platform_key, + ) + ) diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py new file mode 100644 index 0000000..90c4d93 --- /dev/null +++ b/lib/sqlalchemy/testing/provision.py @@ -0,0 +1,416 @@ +import collections +import logging + +from . import config +from . import engines +from . import util +from .. import exc +from .. import inspect +from ..engine import url as sa_url +from ..sql import ddl +from ..sql import schema +from ..util import compat + + +log = logging.getLogger(__name__) + +FOLLOWER_IDENT = None + + +class register(object): + def __init__(self): + self.fns = {} + + @classmethod + def init(cls, fn): + return register().for_db("*")(fn) + + def for_db(self, *dbnames): + def decorate(fn): + for dbname in dbnames: + self.fns[dbname] = fn + return self + + return decorate + + def __call__(self, cfg, *arg): + if isinstance(cfg, compat.string_types): + url = sa_url.make_url(cfg) + elif isinstance(cfg, sa_url.URL): + url = cfg + else: + url = cfg.db.url + backend = url.get_backend_name() + if backend in self.fns: + return self.fns[backend](cfg, *arg) + else: + return self.fns["*"](cfg, *arg) + + +def create_follower_db(follower_ident): + for cfg in _configs_for_db_operation(): + log.info("CREATE database %s, URI %r", follower_ident, cfg.db.url) + create_db(cfg, cfg.db, follower_ident) + + +def setup_config(db_url, options, file_config, follower_ident): + # load the dialect, which should also have it set up its provision + # hooks + + dialect = sa_url.make_url(db_url).get_dialect() + dialect.load_provisioning() + + if follower_ident: + db_url = follower_url_from_main(db_url, follower_ident) + db_opts = {} + update_db_opts(db_url, db_opts) + db_opts["scope"] = "global" + eng = engines.testing_engine(db_url, db_opts) + post_configure_engine(db_url, eng, follower_ident) + eng.connect().close() + + cfg = config.Config.register(eng, db_opts, options, file_config) + + # a symbolic name that tests can use if they need to disambiguate + # names across databases + if follower_ident: + config.ident = follower_ident + + if follower_ident: + configure_follower(cfg, follower_ident) + return cfg + + +def drop_follower_db(follower_ident): + for cfg in _configs_for_db_operation(): + log.info("DROP database %s, URI %r", follower_ident, cfg.db.url) + drop_db(cfg, cfg.db, follower_ident) + + +def generate_db_urls(db_urls, extra_drivers): + """Generate a set of URLs to test given configured URLs plus additional + driver names. + + Given:: + + --dburi postgresql://db1 \ + --dburi postgresql://db2 \ + --dburi postgresql://db2 \ + --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true + + Noting that the default postgresql driver is psycopg2, the output + would be:: + + postgresql+psycopg2://db1 + postgresql+asyncpg://db1 + postgresql+psycopg2://db2 + postgresql+psycopg2://db3 + + That is, for the driver in a --dburi, we want to keep that and use that + driver for each URL it's part of . For a driver that is only + in --dbdrivers, we want to use it just once for one of the URLs. + for a driver that is both coming from --dburi as well as --dbdrivers, + we want to keep it in that dburi. + + Driver specific query options can be specified by added them to the + driver name. For example, to enable the async fallback option for + asyncpg:: + + --dburi postgresql://db1 \ + --dbdriver=asyncpg?async_fallback=true + + """ + urls = set() + + backend_to_driver_we_already_have = collections.defaultdict(set) + + urls_plus_dialects = [ + (url_obj, url_obj.get_dialect()) + for url_obj in [sa_url.make_url(db_url) for db_url in db_urls] + ] + + for url_obj, dialect in urls_plus_dialects: + backend_to_driver_we_already_have[dialect.name].add(dialect.driver) + + backend_to_driver_we_need = {} + + for url_obj, dialect in urls_plus_dialects: + backend = dialect.name + dialect.load_provisioning() + + if backend not in backend_to_driver_we_need: + backend_to_driver_we_need[backend] = extra_per_backend = set( + extra_drivers + ).difference(backend_to_driver_we_already_have[backend]) + else: + extra_per_backend = backend_to_driver_we_need[backend] + + for driver_url in _generate_driver_urls(url_obj, extra_per_backend): + if driver_url in urls: + continue + urls.add(driver_url) + yield driver_url + + +def _generate_driver_urls(url, extra_drivers): + main_driver = url.get_driver_name() + extra_drivers.discard(main_driver) + + url = generate_driver_url(url, main_driver, "") + yield str(url) + + for drv in list(extra_drivers): + + if "?" in drv: + + driver_only, query_str = drv.split("?", 1) + + else: + driver_only = drv + query_str = None + + new_url = generate_driver_url(url, driver_only, query_str) + if new_url: + extra_drivers.remove(drv) + + yield str(new_url) + + +@register.init +def generate_driver_url(url, driver, query_str): + backend = url.get_backend_name() + + new_url = url.set( + drivername="%s+%s" % (backend, driver), + ) + if query_str: + new_url = new_url.update_query_string(query_str) + + try: + new_url.get_dialect() + except exc.NoSuchModuleError: + return None + else: + return new_url + + +def _configs_for_db_operation(): + hosts = set() + + for cfg in config.Config.all_configs(): + cfg.db.dispose() + + for cfg in config.Config.all_configs(): + url = cfg.db.url + backend = url.get_backend_name() + host_conf = (backend, url.username, url.host, url.database) + + if host_conf not in hosts: + yield cfg + hosts.add(host_conf) + + for cfg in config.Config.all_configs(): + cfg.db.dispose() + + +@register.init +def drop_all_schema_objects_pre_tables(cfg, eng): + pass + + +@register.init +def drop_all_schema_objects_post_tables(cfg, eng): + pass + + +def drop_all_schema_objects(cfg, eng): + + drop_all_schema_objects_pre_tables(cfg, eng) + + inspector = inspect(eng) + try: + view_names = inspector.get_view_names() + except NotImplementedError: + pass + else: + with eng.begin() as conn: + for vname in view_names: + conn.execute( + ddl._DropView(schema.Table(vname, schema.MetaData())) + ) + + if config.requirements.schemas.enabled_for_config(cfg): + try: + view_names = inspector.get_view_names(schema="test_schema") + except NotImplementedError: + pass + else: + with eng.begin() as conn: + for vname in view_names: + conn.execute( + ddl._DropView( + schema.Table( + vname, + schema.MetaData(), + schema="test_schema", + ) + ) + ) + + util.drop_all_tables(eng, inspector) + if config.requirements.schemas.enabled_for_config(cfg): + util.drop_all_tables(eng, inspector, schema=cfg.test_schema) + util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2) + + drop_all_schema_objects_post_tables(cfg, eng) + + if config.requirements.sequences.enabled_for_config(cfg): + with eng.begin() as conn: + for seq in inspector.get_sequence_names(): + conn.execute(ddl.DropSequence(schema.Sequence(seq))) + if config.requirements.schemas.enabled_for_config(cfg): + for schema_name in [cfg.test_schema, cfg.test_schema_2]: + for seq in inspector.get_sequence_names( + schema=schema_name + ): + conn.execute( + ddl.DropSequence( + schema.Sequence(seq, schema=schema_name) + ) + ) + + +@register.init +def create_db(cfg, eng, ident): + """Dynamically create a database for testing. + + Used when a test run will employ multiple processes, e.g., when run + via `tox` or `pytest -n4`. + """ + raise NotImplementedError( + "no DB creation routine for cfg: %s" % (eng.url,) + ) + + +@register.init +def drop_db(cfg, eng, ident): + """Drop a database that we dynamically created for testing.""" + raise NotImplementedError("no DB drop routine for cfg: %s" % (eng.url,)) + + +@register.init +def update_db_opts(db_url, db_opts): + """Set database options (db_opts) for a test database that we created.""" + pass + + +@register.init +def post_configure_engine(url, engine, follower_ident): + """Perform extra steps after configuring an engine for testing. + + (For the internal dialects, currently only used by sqlite, oracle) + """ + pass + + +@register.init +def follower_url_from_main(url, ident): + """Create a connection URL for a dynamically-created test database. + + :param url: the connection URL specified when the test run was invoked + :param ident: the pytest-xdist "worker identifier" to be used as the + database name + """ + url = sa_url.make_url(url) + return url.set(database=ident) + + +@register.init +def configure_follower(cfg, ident): + """Create dialect-specific config settings for a follower database.""" + pass + + +@register.init +def run_reap_dbs(url, ident): + """Remove databases that were created during the test process, after the + process has ended. + + This is an optional step that is invoked for certain backends that do not + reliably release locks on the database as long as a process is still in + use. For the internal dialects, this is currently only necessary for + mssql and oracle. + """ + pass + + +def reap_dbs(idents_file): + log.info("Reaping databases...") + + urls = collections.defaultdict(set) + idents = collections.defaultdict(set) + dialects = {} + + with open(idents_file) as file_: + for line in file_: + line = line.strip() + db_name, db_url = line.split(" ") + url_obj = sa_url.make_url(db_url) + if db_name not in dialects: + dialects[db_name] = url_obj.get_dialect() + dialects[db_name].load_provisioning() + url_key = (url_obj.get_backend_name(), url_obj.host) + urls[url_key].add(db_url) + idents[url_key].add(db_name) + + for url_key in urls: + url = list(urls[url_key])[0] + ident = idents[url_key] + run_reap_dbs(url, ident) + + +@register.init +def temp_table_keyword_args(cfg, eng): + """Specify keyword arguments for creating a temporary Table. + + Dialect-specific implementations of this method will return the + kwargs that are passed to the Table method when creating a temporary + table for testing, e.g., in the define_temp_tables method of the + ComponentReflectionTest class in suite/test_reflection.py + """ + raise NotImplementedError( + "no temp table keyword args routine for cfg: %s" % (eng.url,) + ) + + +@register.init +def prepare_for_drop_tables(config, connection): + pass + + +@register.init +def stop_test_class_outside_fixtures(config, db, testcls): + pass + + +@register.init +def get_temp_table_name(cfg, eng, base_name): + """Specify table name for creating a temporary Table. + + Dialect-specific implementations of this method will return the + name to use when creating a temporary table for testing, + e.g., in the define_temp_tables method of the + ComponentReflectionTest class in suite/test_reflection.py + + Default to just the base name since that's what most dialects will + use. The mssql dialect's implementation will need a "#" prepended. + """ + return base_name + + +@register.init +def set_default_schema_on_connection(cfg, dbapi_connection, schema_name): + raise NotImplementedError( + "backend does not implement a schema name set function: %s" + % (cfg.db.url,) + ) diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py new file mode 100644 index 0000000..857d1fd --- /dev/null +++ b/lib/sqlalchemy/testing/requirements.py @@ -0,0 +1,1518 @@ +# testing/requirements.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Global database feature support policy. + +Provides decorators to mark tests requiring specific feature support from the +target database. + +External dialect test suites should subclass SuiteRequirements +to provide specific inclusion/exclusions. + +""" + +import platform +import sys + +from . import exclusions +from . import only_on +from .. import util +from ..pool import QueuePool + + +class Requirements(object): + pass + + +class SuiteRequirements(Requirements): + @property + def create_table(self): + """target platform can emit basic CreateTable DDL.""" + + return exclusions.open() + + @property + def drop_table(self): + """target platform can emit basic DropTable DDL.""" + + return exclusions.open() + + @property + def table_ddl_if_exists(self): + """target platform supports IF NOT EXISTS / IF EXISTS for tables.""" + + return exclusions.closed() + + @property + def index_ddl_if_exists(self): + """target platform supports IF NOT EXISTS / IF EXISTS for indexes.""" + + return exclusions.closed() + + @property + def foreign_keys(self): + """Target database must support foreign keys.""" + + return exclusions.open() + + @property + def table_value_constructor(self): + """Database / dialect supports a query like:: + + SELECT * FROM VALUES ( (c1, c2), (c1, c2), ...) + AS some_table(col1, col2) + + SQLAlchemy generates this with the :func:`_sql.values` function. + + """ + return exclusions.closed() + + @property + def standard_cursor_sql(self): + """Target database passes SQL-92 style statements to cursor.execute() + when a statement like select() or insert() is run. + + A very small portion of dialect-level tests will ensure that certain + conditions are present in SQL strings, and these tests use very basic + SQL that will work on any SQL-like platform in order to assert results. + + It's normally a given for any pep-249 DBAPI that a statement like + "SELECT id, name FROM table WHERE some_table.id=5" will work. + However, there are dialects that don't actually produce SQL Strings + and instead may work with symbolic objects instead, or dialects that + aren't working with SQL, so for those this requirement can be marked + as excluded. + + """ + + return exclusions.open() + + @property + def on_update_cascade(self): + """target database must support ON UPDATE..CASCADE behavior in + foreign keys.""" + + return exclusions.open() + + @property + def non_updating_cascade(self): + """target database must *not* support ON UPDATE..CASCADE behavior in + foreign keys.""" + return exclusions.closed() + + @property + def deferrable_fks(self): + return exclusions.closed() + + @property + def on_update_or_deferrable_fks(self): + # TODO: exclusions should be composable, + # somehow only_if([x, y]) isn't working here, negation/conjunctions + # getting confused. + return exclusions.only_if( + lambda: self.on_update_cascade.enabled + or self.deferrable_fks.enabled + ) + + @property + def queue_pool(self): + """target database is using QueuePool""" + + def go(config): + return isinstance(config.db.pool, QueuePool) + + return exclusions.only_if(go) + + @property + def self_referential_foreign_keys(self): + """Target database must support self-referential foreign keys.""" + + return exclusions.open() + + @property + def foreign_key_ddl(self): + """Target database must support the DDL phrases for FOREIGN KEY.""" + + return exclusions.open() + + @property + def named_constraints(self): + """target database must support names for constraints.""" + + return exclusions.open() + + @property + def implicitly_named_constraints(self): + """target database must apply names to unnamed constraints.""" + + return exclusions.open() + + @property + def subqueries(self): + """Target database must support subqueries.""" + + return exclusions.open() + + @property + def offset(self): + """target database can render OFFSET, or an equivalent, in a + SELECT. + """ + + return exclusions.open() + + @property + def bound_limit_offset(self): + """target database can render LIMIT and/or OFFSET using a bound + parameter + """ + + return exclusions.open() + + @property + def sql_expression_limit_offset(self): + """target database can render LIMIT and/or OFFSET with a complete + SQL expression, such as one that uses the addition operator. + parameter + """ + + return exclusions.open() + + @property + def parens_in_union_contained_select_w_limit_offset(self): + """Target database must support parenthesized SELECT in UNION + when LIMIT/OFFSET is specifically present. + + E.g. (SELECT ...) UNION (SELECT ..) + + This is known to fail on SQLite. + + """ + return exclusions.open() + + @property + def parens_in_union_contained_select_wo_limit_offset(self): + """Target database must support parenthesized SELECT in UNION + when OFFSET/LIMIT is specifically not present. + + E.g. (SELECT ... LIMIT ..) UNION (SELECT .. OFFSET ..) + + This is known to fail on SQLite. It also fails on Oracle + because without LIMIT/OFFSET, there is currently no step that + creates an additional subquery. + + """ + return exclusions.open() + + @property + def boolean_col_expressions(self): + """Target database must support boolean expressions as columns""" + + return exclusions.closed() + + @property + def nullable_booleans(self): + """Target database allows boolean columns to store NULL.""" + + return exclusions.open() + + @property + def nullsordering(self): + """Target backends that support nulls ordering.""" + + return exclusions.closed() + + @property + def standalone_binds(self): + """target database/driver supports bound parameters as column + expressions without being in the context of a typed column. + """ + return exclusions.closed() + + @property + def standalone_null_binds_whereclause(self): + """target database/driver supports bound parameters with NULL in the + WHERE clause, in situations where it has to be typed. + + """ + return exclusions.open() + + @property + def intersect(self): + """Target database must support INTERSECT or equivalent.""" + return exclusions.closed() + + @property + def except_(self): + """Target database must support EXCEPT or equivalent (i.e. MINUS).""" + return exclusions.closed() + + @property + def window_functions(self): + """Target database must support window functions.""" + return exclusions.closed() + + @property + def ctes(self): + """Target database supports CTEs""" + + return exclusions.closed() + + @property + def ctes_with_update_delete(self): + """target database supports CTES that ride on top of a normal UPDATE + or DELETE statement which refers to the CTE in a correlated subquery. + + """ + + return exclusions.closed() + + @property + def ctes_on_dml(self): + """target database supports CTES which consist of INSERT, UPDATE + or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)""" + + return exclusions.closed() + + @property + def autoincrement_insert(self): + """target platform generates new surrogate integer primary key values + when insert() is executed, excluding the pk column.""" + + return exclusions.open() + + @property + def fetch_rows_post_commit(self): + """target platform will allow cursor.fetchone() to proceed after a + COMMIT. + + Typically this refers to an INSERT statement with RETURNING which + is invoked within "autocommit". If the row can be returned + after the autocommit, then this rule can be open. + + """ + + return exclusions.open() + + @property + def group_by_complex_expression(self): + """target platform supports SQL expressions in GROUP BY + + e.g. + + SELECT x + y AS somelabel FROM table GROUP BY x + y + + """ + + return exclusions.open() + + @property + def sane_rowcount(self): + return exclusions.skip_if( + lambda config: not config.db.dialect.supports_sane_rowcount, + "driver doesn't support 'sane' rowcount", + ) + + @property + def sane_multi_rowcount(self): + return exclusions.fails_if( + lambda config: not config.db.dialect.supports_sane_multi_rowcount, + "driver %(driver)s %(doesnt_support)s 'sane' multi row count", + ) + + @property + def sane_rowcount_w_returning(self): + return exclusions.fails_if( + lambda config: not ( + config.db.dialect.supports_sane_rowcount_returning + ), + "driver doesn't support 'sane' rowcount when returning is on", + ) + + @property + def empty_inserts(self): + """target platform supports INSERT with no values, i.e. + INSERT DEFAULT VALUES or equivalent.""" + + return exclusions.only_if( + lambda config: config.db.dialect.supports_empty_insert + or config.db.dialect.supports_default_values + or config.db.dialect.supports_default_metavalue, + "empty inserts not supported", + ) + + @property + def empty_inserts_executemany(self): + """target platform supports INSERT with no values, i.e. + INSERT DEFAULT VALUES or equivalent, within executemany()""" + + return self.empty_inserts + + @property + def insert_from_select(self): + """target platform supports INSERT from a SELECT.""" + + return exclusions.open() + + @property + def full_returning(self): + """target platform supports RETURNING completely, including + multiple rows returned. + + """ + + return exclusions.only_if( + lambda config: config.db.dialect.full_returning, + "%(database)s %(does_support)s 'RETURNING of multiple rows'", + ) + + @property + def insert_executemany_returning(self): + """target platform supports RETURNING when INSERT is used with + executemany(), e.g. multiple parameter sets, indicating + as many rows come back as do parameter sets were passed. + + """ + + return exclusions.only_if( + lambda config: config.db.dialect.insert_executemany_returning, + "%(database)s %(does_support)s 'RETURNING of " + "multiple rows with INSERT executemany'", + ) + + @property + def returning(self): + """target platform supports RETURNING for at least one row. + + .. seealso:: + + :attr:`.Requirements.full_returning` + + """ + + return exclusions.only_if( + lambda config: config.db.dialect.implicit_returning, + "%(database)s %(does_support)s 'RETURNING of a single row'", + ) + + @property + def tuple_in(self): + """Target platform supports the syntax + "(x, y) IN ((x1, y1), (x2, y2), ...)" + """ + + return exclusions.closed() + + @property + def tuple_in_w_empty(self): + """Target platform tuple IN w/ empty set""" + return self.tuple_in + + @property + def duplicate_names_in_cursor_description(self): + """target platform supports a SELECT statement that has + the same name repeated more than once in the columns list.""" + + return exclusions.open() + + @property + def denormalized_names(self): + """Target database must have 'denormalized', i.e. + UPPERCASE as case insensitive names.""" + + return exclusions.skip_if( + lambda config: not config.db.dialect.requires_name_normalize, + "Backend does not require denormalized names.", + ) + + @property + def multivalues_inserts(self): + """target database must support multiple VALUES clauses in an + INSERT statement.""" + + return exclusions.skip_if( + lambda config: not config.db.dialect.supports_multivalues_insert, + "Backend does not support multirow inserts.", + ) + + @property + def implements_get_lastrowid(self): + """target dialect implements the executioncontext.get_lastrowid() + method without reliance on RETURNING. + + """ + return exclusions.open() + + @property + def emulated_lastrowid(self): + """target dialect retrieves cursor.lastrowid, or fetches + from a database-side function after an insert() construct executes, + within the get_lastrowid() method. + + Only dialects that "pre-execute", or need RETURNING to get last + inserted id, would return closed/fail/skip for this. + + """ + return exclusions.closed() + + @property + def emulated_lastrowid_even_with_sequences(self): + """target dialect retrieves cursor.lastrowid or an equivalent + after an insert() construct executes, even if the table has a + Sequence on it. + + """ + return exclusions.closed() + + @property + def dbapi_lastrowid(self): + """target platform includes a 'lastrowid' accessor on the DBAPI + cursor object. + + """ + return exclusions.closed() + + @property + def views(self): + """Target database must support VIEWs.""" + + return exclusions.closed() + + @property + def schemas(self): + """Target database must support external schemas, and have one + named 'test_schema'.""" + + return only_on(lambda config: config.db.dialect.supports_schemas) + + @property + def cross_schema_fk_reflection(self): + """target system must support reflection of inter-schema + foreign keys""" + return exclusions.closed() + + @property + def foreign_key_constraint_name_reflection(self): + """Target supports refleciton of FOREIGN KEY constraints and + will return the name of the constraint that was used in the + "CONSTRAINT <name> FOREIGN KEY" DDL. + + MySQL prior to version 8 and MariaDB prior to version 10.5 + don't support this. + + """ + return exclusions.closed() + + @property + def implicit_default_schema(self): + """target system has a strong concept of 'default' schema that can + be referred to implicitly. + + basically, PostgreSQL. + + """ + return exclusions.closed() + + @property + def default_schema_name_switch(self): + """target dialect implements provisioning module including + set_default_schema_on_connection""" + + return exclusions.closed() + + @property + def server_side_cursors(self): + """Target dialect must support server side cursors.""" + + return exclusions.only_if( + [lambda config: config.db.dialect.supports_server_side_cursors], + "no server side cursors support", + ) + + @property + def sequences(self): + """Target database must support SEQUENCEs.""" + + return exclusions.only_if( + [lambda config: config.db.dialect.supports_sequences], + "no sequence support", + ) + + @property + def no_sequences(self): + """the opposite of "sequences", DB does not support sequences at + all.""" + + return exclusions.NotPredicate(self.sequences) + + @property + def sequences_optional(self): + """Target database supports sequences, but also optionally + as a means of generating new PK values.""" + + return exclusions.only_if( + [ + lambda config: config.db.dialect.supports_sequences + and config.db.dialect.sequences_optional + ], + "no sequence support, or sequences not optional", + ) + + @property + def supports_lastrowid(self): + """target database / driver supports cursor.lastrowid as a means + of retrieving the last inserted primary key value. + + note that if the target DB supports sequences also, this is still + assumed to work. This is a new use case brought on by MariaDB 10.3. + + """ + return exclusions.only_if( + [lambda config: config.db.dialect.postfetch_lastrowid] + ) + + @property + def no_lastrowid_support(self): + """the opposite of supports_lastrowid""" + return exclusions.only_if( + [lambda config: not config.db.dialect.postfetch_lastrowid] + ) + + @property + def reflects_pk_names(self): + return exclusions.closed() + + @property + def table_reflection(self): + """target database has general support for table reflection""" + return exclusions.open() + + @property + def reflect_tables_no_columns(self): + """target database supports creation and reflection of tables with no + columns, or at least tables that seem to have no columns.""" + + return exclusions.closed() + + @property + def comment_reflection(self): + return exclusions.closed() + + @property + def view_column_reflection(self): + """target database must support retrieval of the columns in a view, + similarly to how a table is inspected. + + This does not include the full CREATE VIEW definition. + + """ + return self.views + + @property + def view_reflection(self): + """target database must support inspection of the full CREATE VIEW + definition.""" + return self.views + + @property + def schema_reflection(self): + return self.schemas + + @property + def primary_key_constraint_reflection(self): + return exclusions.open() + + @property + def foreign_key_constraint_reflection(self): + return exclusions.open() + + @property + def foreign_key_constraint_option_reflection_ondelete(self): + return exclusions.closed() + + @property + def fk_constraint_option_reflection_ondelete_restrict(self): + return exclusions.closed() + + @property + def fk_constraint_option_reflection_ondelete_noaction(self): + return exclusions.closed() + + @property + def foreign_key_constraint_option_reflection_onupdate(self): + return exclusions.closed() + + @property + def fk_constraint_option_reflection_onupdate_restrict(self): + return exclusions.closed() + + @property + def temp_table_reflection(self): + return exclusions.open() + + @property + def temp_table_reflect_indexes(self): + return self.temp_table_reflection + + @property + def temp_table_names(self): + """target dialect supports listing of temporary table names""" + return exclusions.closed() + + @property + def temporary_tables(self): + """target database supports temporary tables""" + return exclusions.open() + + @property + def temporary_views(self): + """target database supports temporary views""" + return exclusions.closed() + + @property + def index_reflection(self): + return exclusions.open() + + @property + def index_reflects_included_columns(self): + return exclusions.closed() + + @property + def indexes_with_ascdesc(self): + """target database supports CREATE INDEX with per-column ASC/DESC.""" + return exclusions.open() + + @property + def indexes_with_expressions(self): + """target database supports CREATE INDEX against SQL expressions.""" + return exclusions.closed() + + @property + def unique_constraint_reflection(self): + """target dialect supports reflection of unique constraints""" + return exclusions.open() + + @property + def check_constraint_reflection(self): + """target dialect supports reflection of check constraints""" + return exclusions.closed() + + @property + def duplicate_key_raises_integrity_error(self): + """target dialect raises IntegrityError when reporting an INSERT + with a primary key violation. (hint: it should) + + """ + return exclusions.open() + + @property + def unbounded_varchar(self): + """Target database must support VARCHAR with no length""" + + return exclusions.open() + + @property + def unicode_data(self): + """Target database/dialect must support Python unicode objects with + non-ASCII characters represented, delivered as bound parameters + as well as in result rows. + + """ + return exclusions.open() + + @property + def unicode_ddl(self): + """Target driver must support some degree of non-ascii symbol + names. + """ + return exclusions.closed() + + @property + def symbol_names_w_double_quote(self): + """Target driver can create tables with a name like 'some " table'""" + return exclusions.open() + + @property + def datetime_literals(self): + """target dialect supports rendering of a date, time, or datetime as a + literal string, e.g. via the TypeEngine.literal_processor() method. + + """ + + return exclusions.closed() + + @property + def datetime(self): + """target dialect supports representation of Python + datetime.datetime() objects.""" + + return exclusions.open() + + @property + def datetime_timezone(self): + """target dialect supports representation of Python + datetime.datetime() with tzinfo with DateTime(timezone=True).""" + + return exclusions.closed() + + @property + def time_timezone(self): + """target dialect supports representation of Python + datetime.time() with tzinfo with Time(timezone=True).""" + + return exclusions.closed() + + @property + def datetime_implicit_bound(self): + """target dialect when given a datetime object will bind it such + that the database server knows the object is a datetime, and not + a plain string. + + """ + return exclusions.open() + + @property + def datetime_microseconds(self): + """target dialect supports representation of Python + datetime.datetime() with microsecond objects.""" + + return exclusions.open() + + @property + def timestamp_microseconds(self): + """target dialect supports representation of Python + datetime.datetime() with microsecond objects but only + if TIMESTAMP is used.""" + return exclusions.closed() + + @property + def timestamp_microseconds_implicit_bound(self): + """target dialect when given a datetime object which also includes + a microseconds portion when using the TIMESTAMP data type + will bind it such that the database server knows + the object is a datetime with microseconds, and not a plain string. + + """ + return self.timestamp_microseconds + + @property + def datetime_historic(self): + """target dialect supports representation of Python + datetime.datetime() objects with historic (pre 1970) values.""" + + return exclusions.closed() + + @property + def date(self): + """target dialect supports representation of Python + datetime.date() objects.""" + + return exclusions.open() + + @property + def date_coerces_from_datetime(self): + """target dialect accepts a datetime object as the target + of a date column.""" + + return exclusions.open() + + @property + def date_historic(self): + """target dialect supports representation of Python + datetime.datetime() objects with historic (pre 1970) values.""" + + return exclusions.closed() + + @property + def time(self): + """target dialect supports representation of Python + datetime.time() objects.""" + + return exclusions.open() + + @property + def time_microseconds(self): + """target dialect supports representation of Python + datetime.time() with microsecond objects.""" + + return exclusions.open() + + @property + def binary_comparisons(self): + """target database/driver can allow BLOB/BINARY fields to be compared + against a bound parameter value. + """ + + return exclusions.open() + + @property + def binary_literals(self): + """target backend supports simple binary literals, e.g. an + expression like:: + + SELECT CAST('foo' AS BINARY) + + Where ``BINARY`` is the type emitted from :class:`.LargeBinary`, + e.g. it could be ``BLOB`` or similar. + + Basically fails on Oracle. + + """ + + return exclusions.open() + + @property + def autocommit(self): + """target dialect supports 'AUTOCOMMIT' as an isolation_level""" + return exclusions.closed() + + @property + def isolation_level(self): + """target dialect supports general isolation level settings. + + Note that this requirement, when enabled, also requires that + the get_isolation_levels() method be implemented. + + """ + return exclusions.closed() + + def get_isolation_levels(self, config): + """Return a structure of supported isolation levels for the current + testing dialect. + + The structure indicates to the testing suite what the expected + "default" isolation should be, as well as the other values that + are accepted. The dictionary has two keys, "default" and "supported". + The "supported" key refers to a list of all supported levels and + it should include AUTOCOMMIT if the dialect supports it. + + If the :meth:`.DefaultRequirements.isolation_level` requirement is + not open, then this method has no return value. + + E.g.:: + + >>> testing.requirements.get_isolation_levels() + { + "default": "READ_COMMITTED", + "supported": [ + "SERIALIZABLE", "READ UNCOMMITTED", + "READ COMMITTED", "REPEATABLE READ", + "AUTOCOMMIT" + ] + } + """ + + @property + def json_type(self): + """target platform implements a native JSON type.""" + + return exclusions.closed() + + @property + def json_array_indexes(self): + """target platform supports numeric array indexes + within a JSON structure""" + + return self.json_type + + @property + def json_index_supplementary_unicode_element(self): + return exclusions.open() + + @property + def legacy_unconditional_json_extract(self): + """Backend has a JSON_EXTRACT or similar function that returns a + valid JSON string in all cases. + + Used to test a legacy feature and is not needed. + + """ + return exclusions.closed() + + @property + def precision_numerics_general(self): + """target backend has general support for moderately high-precision + numerics.""" + return exclusions.open() + + @property + def precision_numerics_enotation_small(self): + """target backend supports Decimal() objects using E notation + to represent very small values.""" + return exclusions.closed() + + @property + def precision_numerics_enotation_large(self): + """target backend supports Decimal() objects using E notation + to represent very large values.""" + return exclusions.closed() + + @property + def precision_numerics_many_significant_digits(self): + """target backend supports values with many digits on both sides, + such as 319438950232418390.273596, 87673.594069654243 + + """ + return exclusions.closed() + + @property + def cast_precision_numerics_many_significant_digits(self): + """same as precision_numerics_many_significant_digits but within the + context of a CAST statement (hello MySQL) + + """ + return self.precision_numerics_many_significant_digits + + @property + def implicit_decimal_binds(self): + """target backend will return a selected Decimal as a Decimal, not + a string. + + e.g.:: + + expr = decimal.Decimal("15.7563") + + value = e.scalar( + select(literal(expr)) + ) + + assert value == expr + + See :ticket:`4036` + + """ + + return exclusions.open() + + @property + def nested_aggregates(self): + """target database can select an aggregate from a subquery that's + also using an aggregate + + """ + return exclusions.open() + + @property + def recursive_fk_cascade(self): + """target database must support ON DELETE CASCADE on a self-referential + foreign key + + """ + return exclusions.open() + + @property + def precision_numerics_retains_significant_digits(self): + """A precision numeric type will return empty significant digits, + i.e. a value such as 10.000 will come back in Decimal form with + the .000 maintained.""" + + return exclusions.closed() + + @property + def infinity_floats(self): + """The Float type can persist and load float('inf'), float('-inf').""" + + return exclusions.closed() + + @property + def precision_generic_float_type(self): + """target backend will return native floating point numbers with at + least seven decimal places when using the generic Float type. + + """ + return exclusions.open() + + @property + def floats_to_four_decimals(self): + """target backend can return a floating-point number with four + significant digits (such as 15.7563) accurately + (i.e. without FP inaccuracies, such as 15.75629997253418). + + """ + return exclusions.open() + + @property + def fetch_null_from_numeric(self): + """target backend doesn't crash when you try to select a NUMERIC + value that has a value of NULL. + + Added to support Pyodbc bug #351. + """ + + return exclusions.open() + + @property + def text_type(self): + """Target database must support an unbounded Text() " + "type such as TEXT or CLOB""" + + return exclusions.open() + + @property + def empty_strings_varchar(self): + """target database can persist/return an empty string with a + varchar. + + """ + return exclusions.open() + + @property + def empty_strings_text(self): + """target database can persist/return an empty string with an + unbounded text.""" + + return exclusions.open() + + @property + def expressions_against_unbounded_text(self): + """target database supports use of an unbounded textual field in a + WHERE clause.""" + + return exclusions.open() + + @property + def selectone(self): + """target driver must support the literal statement 'select 1'""" + return exclusions.open() + + @property + def savepoints(self): + """Target database must support savepoints.""" + + return exclusions.closed() + + @property + def two_phase_transactions(self): + """Target database must support two-phase transactions.""" + + return exclusions.closed() + + @property + def update_from(self): + """Target must support UPDATE..FROM syntax""" + return exclusions.closed() + + @property + def delete_from(self): + """Target must support DELETE FROM..FROM or DELETE..USING syntax""" + return exclusions.closed() + + @property + def update_where_target_in_subquery(self): + """Target must support UPDATE (or DELETE) where the same table is + present in a subquery in the WHERE clause. + + This is an ANSI-standard syntax that apparently MySQL can't handle, + such as:: + + UPDATE documents SET flag=1 WHERE documents.title IN + (SELECT max(documents.title) AS title + FROM documents GROUP BY documents.user_id + ) + + """ + return exclusions.open() + + @property + def mod_operator_as_percent_sign(self): + """target database must use a plain percent '%' as the 'modulus' + operator.""" + return exclusions.closed() + + @property + def percent_schema_names(self): + """target backend supports weird identifiers with percent signs + in them, e.g. 'some % column'. + + this is a very weird use case but often has problems because of + DBAPIs that use python formatting. It's not a critical use + case either. + + """ + return exclusions.closed() + + @property + def order_by_col_from_union(self): + """target database supports ordering by a column from a SELECT + inside of a UNION + + E.g. (SELECT id, ...) UNION (SELECT id, ...) ORDER BY id + + """ + return exclusions.open() + + @property + def order_by_label_with_expression(self): + """target backend supports ORDER BY a column label within an + expression. + + Basically this:: + + select data as foo from test order by foo || 'bar' + + Lots of databases including PostgreSQL don't support this, + so this is off by default. + + """ + return exclusions.closed() + + @property + def order_by_collation(self): + def check(config): + try: + self.get_order_by_collation(config) + return False + except NotImplementedError: + return True + + return exclusions.skip_if(check) + + def get_order_by_collation(self, config): + raise NotImplementedError() + + @property + def unicode_connections(self): + """Target driver must support non-ASCII characters being passed at + all. + """ + return exclusions.open() + + @property + def graceful_disconnects(self): + """Target driver must raise a DBAPI-level exception, such as + InterfaceError, when the underlying connection has been closed + and the execute() method is called. + """ + return exclusions.open() + + @property + def independent_connections(self): + """ + Target must support simultaneous, independent database connections. + """ + return exclusions.open() + + @property + def skip_mysql_on_windows(self): + """Catchall for a large variety of MySQL on Windows failures""" + return exclusions.open() + + @property + def ad_hoc_engines(self): + """Test environment must allow ad-hoc engine/connection creation. + + DBs that scale poorly for many connections, even when closed, i.e. + Oracle, may use the "--low-connections" option which flags this + requirement as not present. + + """ + return exclusions.skip_if( + lambda config: config.options.low_connections + ) + + @property + def no_windows(self): + return exclusions.skip_if(self._running_on_windows()) + + def _running_on_windows(self): + return exclusions.LambdaPredicate( + lambda: platform.system() == "Windows", + description="running on Windows", + ) + + @property + def timing_intensive(self): + return exclusions.requires_tag("timing_intensive") + + @property + def memory_intensive(self): + return exclusions.requires_tag("memory_intensive") + + @property + def threading_with_mock(self): + """Mark tests that use threading and mock at the same time - stability + issues have been observed with coverage + python 3.3 + + """ + return exclusions.skip_if( + lambda config: util.py3k and config.options.has_coverage, + "Stability issues with coverage + py3k", + ) + + @property + def sqlalchemy2_stubs(self): + def check(config): + try: + __import__("sqlalchemy-stubs.ext.mypy") + except ImportError: + return False + else: + return True + + return exclusions.only_if(check) + + @property + def python2(self): + return exclusions.skip_if( + lambda: sys.version_info >= (3,), + "Python version 2.xx is required.", + ) + + @property + def python3(self): + return exclusions.skip_if( + lambda: sys.version_info < (3,), "Python version 3.xx is required." + ) + + @property + def pep520(self): + return self.python36 + + @property + def insert_order_dicts(self): + return self.python37 + + @property + def python36(self): + return exclusions.skip_if( + lambda: sys.version_info < (3, 6), + "Python version 3.6 or greater is required.", + ) + + @property + def python37(self): + return exclusions.skip_if( + lambda: sys.version_info < (3, 7), + "Python version 3.7 or greater is required.", + ) + + @property + def dataclasses(self): + return self.python37 + + @property + def python38(self): + return exclusions.only_if( + lambda: util.py38, "Python 3.8 or above required" + ) + + @property + def cpython(self): + return exclusions.only_if( + lambda: util.cpython, "cPython interpreter needed" + ) + + @property + def patch_library(self): + def check_lib(): + try: + __import__("patch") + except ImportError: + return False + else: + return True + + return exclusions.only_if(check_lib, "patch library needed") + + @property + def non_broken_pickle(self): + from sqlalchemy.util import pickle + + return exclusions.only_if( + lambda: util.cpython + and pickle.__name__ == "cPickle" + or sys.version_info >= (3, 2), + "Needs cPickle+cPython or newer Python 3 pickle", + ) + + @property + def predictable_gc(self): + """target platform must remove all cycles unconditionally when + gc.collect() is called, as well as clean out unreferenced subclasses. + + """ + return self.cpython + + @property + def no_coverage(self): + """Test should be skipped if coverage is enabled. + + This is to block tests that exercise libraries that seem to be + sensitive to coverage, such as PostgreSQL notice logging. + + """ + return exclusions.skip_if( + lambda config: config.options.has_coverage, + "Issues observed when coverage is enabled", + ) + + def _has_mysql_on_windows(self, config): + return False + + def _has_mysql_fully_case_sensitive(self, config): + return False + + @property + def sqlite(self): + return exclusions.skip_if(lambda: not self._has_sqlite()) + + @property + def cextensions(self): + return exclusions.skip_if( + lambda: not util.has_compiled_ext(), "C extensions not installed" + ) + + def _has_sqlite(self): + from sqlalchemy import create_engine + + try: + create_engine("sqlite://") + return True + except ImportError: + return False + + @property + def async_dialect(self): + """dialect makes use of await_() to invoke operations on the DBAPI.""" + + return exclusions.closed() + + @property + def asyncio(self): + return self.greenlet + + @property + def greenlet(self): + def go(config): + try: + import greenlet # noqa: F401 + except ImportError: + return False + else: + return True + + return exclusions.only_if(go) + + @property + def computed_columns(self): + "Supports computed columns" + return exclusions.closed() + + @property + def computed_columns_stored(self): + "Supports computed columns with `persisted=True`" + return exclusions.closed() + + @property + def computed_columns_virtual(self): + "Supports computed columns with `persisted=False`" + return exclusions.closed() + + @property + def computed_columns_default_persisted(self): + """If the default persistence is virtual or stored when `persisted` + is omitted""" + return exclusions.closed() + + @property + def computed_columns_reflect_persisted(self): + """If persistence information is returned by the reflection of + computed columns""" + return exclusions.closed() + + @property + def supports_distinct_on(self): + """If a backend supports the DISTINCT ON in a select""" + return exclusions.closed() + + @property + def supports_is_distinct_from(self): + """Supports some form of "x IS [NOT] DISTINCT FROM y" construct. + Different dialects will implement their own flavour, e.g., + sqlite will emit "x IS NOT y" instead of "x IS DISTINCT FROM y". + + .. seealso:: + + :meth:`.ColumnOperators.is_distinct_from` + + """ + return exclusions.skip_if( + lambda config: not config.db.dialect.supports_is_distinct_from, + "driver doesn't support an IS DISTINCT FROM construct", + ) + + @property + def identity_columns(self): + """If a backend supports GENERATED { ALWAYS | BY DEFAULT } + AS IDENTITY""" + return exclusions.closed() + + @property + def identity_columns_standard(self): + """If a backend supports GENERATED { ALWAYS | BY DEFAULT } + AS IDENTITY with a standard syntax. + This is mainly to exclude MSSql. + """ + return exclusions.closed() + + @property + def regexp_match(self): + """backend supports the regexp_match operator.""" + return exclusions.closed() + + @property + def regexp_replace(self): + """backend supports the regexp_replace operator.""" + return exclusions.closed() + + @property + def fetch_first(self): + """backend supports the fetch first clause.""" + return exclusions.closed() + + @property + def fetch_percent(self): + """backend supports the fetch first clause with percent.""" + return exclusions.closed() + + @property + def fetch_ties(self): + """backend supports the fetch first clause with ties.""" + return exclusions.closed() + + @property + def fetch_no_order_by(self): + """backend supports the fetch first without order by""" + return exclusions.closed() + + @property + def fetch_offset_with_options(self): + """backend supports the offset when using fetch first with percent + or ties. basically this is "not mssql" + """ + return exclusions.closed() + + @property + def fetch_expression(self): + """backend supports fetch / offset with expression in them, like + + SELECT * FROM some_table + OFFSET 1 + 1 ROWS FETCH FIRST 1 + 1 ROWS ONLY + """ + return exclusions.closed() + + @property + def autoincrement_without_sequence(self): + """If autoincrement=True on a column does not require an explicit + sequence. This should be false only for oracle. + """ + return exclusions.open() + + @property + def generic_classes(self): + "If X[Y] can be implemented with ``__class_getitem__``. py3.7+" + return exclusions.only_if(lambda: util.py37) diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py new file mode 100644 index 0000000..bff07a5 --- /dev/null +++ b/lib/sqlalchemy/testing/schema.py @@ -0,0 +1,218 @@ +# testing/schema.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import sys + +from . import config +from . import exclusions +from .. import event +from .. import schema +from .. import types as sqltypes +from ..util import OrderedDict + + +__all__ = ["Table", "Column"] + +table_options = {} + + +def Table(*args, **kw): + """A schema.Table wrapper/hook for dialect-specific tweaks.""" + + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} + + kw.update(table_options) + + if exclusions.against(config._current, "mysql"): + if ( + "mysql_engine" not in kw + and "mysql_type" not in kw + and "autoload_with" not in kw + ): + if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts: + kw["mysql_engine"] = "InnoDB" + else: + kw["mysql_engine"] = "MyISAM" + elif exclusions.against(config._current, "mariadb"): + if ( + "mariadb_engine" not in kw + and "mariadb_type" not in kw + and "autoload_with" not in kw + ): + if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts: + kw["mariadb_engine"] = "InnoDB" + else: + kw["mariadb_engine"] = "MyISAM" + + # Apply some default cascading rules for self-referential foreign keys. + # MySQL InnoDB has some issues around selecting self-refs too. + if exclusions.against(config._current, "firebird"): + table_name = args[0] + unpack = config.db.dialect.identifier_preparer.unformat_identifiers + + # Only going after ForeignKeys in Columns. May need to + # expand to ForeignKeyConstraint too. + fks = [ + fk + for col in args + if isinstance(col, schema.Column) + for fk in col.foreign_keys + ] + + for fk in fks: + # root around in raw spec + ref = fk._colspec + if isinstance(ref, schema.Column): + name = ref.table.name + else: + # take just the table name: on FB there cannot be + # a schema, so the first element is always the + # table name, possibly followed by the field name + name = unpack(ref)[0] + if name == table_name: + if fk.ondelete is None: + fk.ondelete = "CASCADE" + if fk.onupdate is None: + fk.onupdate = "CASCADE" + + return schema.Table(*args, **kw) + + +def Column(*args, **kw): + """A schema.Column wrapper/hook for dialect-specific tweaks.""" + + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} + + if not config.requirements.foreign_key_ddl.enabled_for_config(config): + args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)] + + col = schema.Column(*args, **kw) + if test_opts.get("test_needs_autoincrement", False) and kw.get( + "primary_key", False + ): + + if col.default is None and col.server_default is None: + col.autoincrement = True + + # allow any test suite to pick up on this + col.info["test_needs_autoincrement"] = True + + # hardcoded rule for firebird, oracle; this should + # be moved out + if exclusions.against(config._current, "firebird", "oracle"): + + def add_seq(c, tbl): + c._init_items( + schema.Sequence( + _truncate_name( + config.db.dialect, tbl.name + "_" + c.name + "_seq" + ), + optional=True, + ) + ) + + event.listen(col, "after_parent_attach", add_seq, propagate=True) + return col + + +class eq_type_affinity(object): + """Helper to compare types inside of datastructures based on affinity. + + E.g.:: + + eq_( + inspect(connection).get_columns("foo"), + [ + { + "name": "id", + "type": testing.eq_type_affinity(sqltypes.INTEGER), + "nullable": False, + "default": None, + "autoincrement": False, + }, + { + "name": "data", + "type": testing.eq_type_affinity(sqltypes.NullType), + "nullable": True, + "default": None, + "autoincrement": False, + }, + ], + ) + + """ + + def __init__(self, target): + self.target = sqltypes.to_instance(target) + + def __eq__(self, other): + return self.target._type_affinity is other._type_affinity + + def __ne__(self, other): + return self.target._type_affinity is not other._type_affinity + + +class eq_clause_element(object): + """Helper to compare SQL structures based on compare()""" + + def __init__(self, target): + self.target = target + + def __eq__(self, other): + return self.target.compare(other) + + def __ne__(self, other): + return not self.target.compare(other) + + +def _truncate_name(dialect, name): + if len(name) > dialect.max_identifier_length: + return ( + name[0 : max(dialect.max_identifier_length - 6, 0)] + + "_" + + hex(hash(name) % 64)[2:] + ) + else: + return name + + +def pep435_enum(name): + # Implements PEP 435 in the minimal fashion needed by SQLAlchemy + __members__ = OrderedDict() + + def __init__(self, name, value, alias=None): + self.name = name + self.value = value + self.__members__[name] = self + value_to_member[value] = self + setattr(self.__class__, name, self) + if alias: + self.__members__[alias] = self + setattr(self.__class__, alias, self) + + value_to_member = {} + + @classmethod + def get(cls, value): + return value_to_member[value] + + someenum = type( + name, + (object,), + {"__members__": __members__, "__init__": __init__, "get": get}, + ) + + # getframe() trick for pickling I don't understand courtesy + # Python namedtuple() + try: + module = sys._getframe(1).f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): + pass + if module is not None: + someenum.__module__ = module + + return someenum diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py new file mode 100644 index 0000000..30817e1 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/__init__.py @@ -0,0 +1,13 @@ +from .test_cte import * # noqa +from .test_ddl import * # noqa +from .test_deprecations import * # noqa +from .test_dialect import * # noqa +from .test_insert import * # noqa +from .test_reflection import * # noqa +from .test_results import * # noqa +from .test_rowcount import * # noqa +from .test_select import * # noqa +from .test_sequence import * # noqa +from .test_types import * # noqa +from .test_unicode_ddl import * # noqa +from .test_update_delete import * # noqa diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py new file mode 100644 index 0000000..a94ee55 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_cte.py @@ -0,0 +1,204 @@ +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import ForeignKey +from ... import Integer +from ... import select +from ... import String +from ... import testing + + +class CTETest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("ctes",) + + run_inserts = "each" + run_deletes = "each" + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", ForeignKey("some_table.id")), + ) + + Table( + "some_other_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "d1", "parent_id": None}, + {"id": 2, "data": "d2", "parent_id": 1}, + {"id": 3, "data": "d3", "parent_id": 1}, + {"id": 4, "data": "d4", "parent_id": 3}, + {"id": 5, "data": "d5", "parent_id": 3}, + ], + ) + + def test_select_nonrecursive_round_trip(self, connection): + some_table = self.tables.some_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + result = connection.execute( + select(cte.c.data).where(cte.c.data.in_(["d4", "d5"])) + ) + eq_(result.fetchall(), [("d4",)]) + + def test_select_recursive_round_trip(self, connection): + some_table = self.tables.some_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte", recursive=True) + ) + + cte_alias = cte.alias("c1") + st1 = some_table.alias() + # note that SQL Server requires this to be UNION ALL, + # can't be UNION + cte = cte.union_all( + select(st1).where(st1.c.id == cte_alias.c.parent_id) + ) + result = connection.execute( + select(cte.c.data) + .where(cte.c.data != "d2") + .order_by(cte.c.data.desc()) + ) + eq_( + result.fetchall(), + [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)], + ) + + def test_insert_from_select_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(cte) + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)], + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.update_from + def test_update_from_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.update() + .values(parent_id=5) + .where(some_other_table.c.data == cte.c.data) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [ + (1, "d1", None), + (2, "d2", 5), + (3, "d3", 5), + (4, "d4", 5), + (5, "d5", 3), + ], + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.delete_from + def test_delete_from_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.delete().where( + some_other_table.c.data == cte.c.data + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(1, "d1", None), (5, "d5", 3)], + ) + + @testing.requires.ctes_with_update_delete + def test_delete_scalar_subq_round_trip(self, connection): + + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.delete().where( + some_other_table.c.data + == select(cte.c.data) + .where(cte.c.id == some_other_table.c.id) + .scalar_subquery() + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(1, "d1", None), (5, "d5", 3)], + ) diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py new file mode 100644 index 0000000..b3fee55 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_ddl.py @@ -0,0 +1,381 @@ +import random + +from . import testing +from .. import config +from .. import fixtures +from .. import util +from ..assertions import eq_ +from ..assertions import is_false +from ..assertions import is_true +from ..config import requirements +from ..schema import Table +from ... import CheckConstraint +from ... import Column +from ... import ForeignKeyConstraint +from ... import Index +from ... import inspect +from ... import Integer +from ... import schema +from ... import String +from ... import UniqueConstraint + + +class TableDDLTest(fixtures.TestBase): + __backend__ = True + + def _simple_fixture(self, schema=None): + return Table( + "test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + schema=schema, + ) + + def _underscore_fixture(self): + return Table( + "_test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("_data", String(50)), + ) + + def _table_index_fixture(self, schema=None): + table = self._simple_fixture(schema=schema) + idx = Index("test_index", table.c.data) + return table, idx + + def _simple_roundtrip(self, table): + with config.db.begin() as conn: + conn.execute(table.insert().values((1, "some data"))) + result = conn.execute(table.select()) + eq_(result.first(), (1, "some data")) + + @requirements.create_table + @util.provide_metadata + def test_create_table(self): + table = self._simple_fixture() + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.create_table + @requirements.schemas + @util.provide_metadata + def test_create_table_schema(self): + table = self._simple_fixture(schema=config.test_schema) + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.drop_table + @util.provide_metadata + def test_drop_table(self): + table = self._simple_fixture() + table.create(config.db, checkfirst=False) + table.drop(config.db, checkfirst=False) + + @requirements.create_table + @util.provide_metadata + def test_underscore_names(self): + table = self._underscore_fixture() + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.comment_reflection + @util.provide_metadata + def test_add_table_comment(self, connection): + table = self._simple_fixture() + table.create(connection, checkfirst=False) + table.comment = "a comment" + connection.execute(schema.SetTableComment(table)) + eq_( + inspect(connection).get_table_comment("test_table"), + {"text": "a comment"}, + ) + + @requirements.comment_reflection + @util.provide_metadata + def test_drop_table_comment(self, connection): + table = self._simple_fixture() + table.create(connection, checkfirst=False) + table.comment = "a comment" + connection.execute(schema.SetTableComment(table)) + connection.execute(schema.DropTableComment(table)) + eq_( + inspect(connection).get_table_comment("test_table"), {"text": None} + ) + + @requirements.table_ddl_if_exists + @util.provide_metadata + def test_create_table_if_not_exists(self, connection): + table = self._simple_fixture() + + connection.execute(schema.CreateTable(table, if_not_exists=True)) + + is_true(inspect(connection).has_table("test_table")) + connection.execute(schema.CreateTable(table, if_not_exists=True)) + + @requirements.index_ddl_if_exists + @util.provide_metadata + def test_create_index_if_not_exists(self, connection): + table, idx = self._table_index_fixture() + + connection.execute(schema.CreateTable(table, if_not_exists=True)) + is_true(inspect(connection).has_table("test_table")) + is_false( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.CreateIndex(idx, if_not_exists=True)) + + is_true( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.CreateIndex(idx, if_not_exists=True)) + + @requirements.table_ddl_if_exists + @util.provide_metadata + def test_drop_table_if_exists(self, connection): + table = self._simple_fixture() + + table.create(connection) + + is_true(inspect(connection).has_table("test_table")) + + connection.execute(schema.DropTable(table, if_exists=True)) + + is_false(inspect(connection).has_table("test_table")) + + connection.execute(schema.DropTable(table, if_exists=True)) + + @requirements.index_ddl_if_exists + @util.provide_metadata + def test_drop_index_if_exists(self, connection): + table, idx = self._table_index_fixture() + + table.create(connection) + + is_true( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.DropIndex(idx, if_exists=True)) + + is_false( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.DropIndex(idx, if_exists=True)) + + +class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest): + pass + + +class LongNameBlowoutTest(fixtures.TestBase): + """test the creation of a variety of DDL structures and ensure + label length limits pass on backends + + """ + + __backend__ = True + + def fk(self, metadata, connection): + convention = { + "fk": "foreign_key_%(table_name)s_" + "%(column_0_N_name)s_" + "%(referred_table_name)s_" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(20)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + test_needs_fk=True, + ) + + cons = ForeignKeyConstraint( + ["aid"], ["a_things_with_stuff.id_long_column_name"] + ) + Table( + "b_related_things_of_value", + metadata, + Column( + "aid", + ), + cons, + test_needs_fk=True, + ) + actual_name = cons.name + + metadata.create_all(connection) + + if testing.requires.foreign_key_constraint_name_reflection.enabled: + insp = inspect(connection) + fks = insp.get_foreign_keys("b_related_things_of_value") + reflected_name = fks[0]["name"] + + return actual_name, reflected_name + else: + return actual_name, None + + def pk(self, metadata, connection): + convention = { + "pk": "primary_key_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + a = Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer, primary_key=True), + ) + cons = a.primary_key + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + pk = insp.get_pk_constraint("a_things_with_stuff") + reflected_name = pk["name"] + return actual_name, reflected_name + + def ix(self, metadata, connection): + convention = { + "ix": "index_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + a = Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer), + ) + cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + ix = insp.get_indexes("a_things_with_stuff") + reflected_name = ix[0]["name"] + return actual_name, reflected_name + + def uq(self, metadata, connection): + convention = { + "uq": "unique_constraint_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + cons = UniqueConstraint("id_long_column_name", "id_another_long_name") + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer), + cons, + ) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + uq = insp.get_unique_constraints("a_things_with_stuff") + reflected_name = uq[0]["name"] + return actual_name, reflected_name + + def ck(self, metadata, connection): + convention = { + "ck": "check_constraint_%(table_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + cons = CheckConstraint("some_long_column_name > 5") + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("some_long_column_name", Integer), + cons, + ) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + ck = insp.get_check_constraints("a_things_with_stuff") + reflected_name = ck[0]["name"] + return actual_name, reflected_name + + @testing.combinations( + ("fk",), + ("pk",), + ("ix",), + ("ck", testing.requires.check_constraint_reflection.as_skips()), + ("uq", testing.requires.unique_constraint_reflection.as_skips()), + argnames="type_", + ) + def test_long_convention_name(self, type_, metadata, connection): + actual_name, reflected_name = getattr(self, type_)( + metadata, connection + ) + + assert len(actual_name) > 255 + + if reflected_name is not None: + overlap = actual_name[0 : len(reflected_name)] + if len(overlap) < len(actual_name): + eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5]) + else: + eq_(overlap, reflected_name) + + +__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest") diff --git a/lib/sqlalchemy/testing/suite/test_deprecations.py b/lib/sqlalchemy/testing/suite/test_deprecations.py new file mode 100644 index 0000000..b36162f --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_deprecations.py @@ -0,0 +1,145 @@ +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import Integer +from ... import select +from ... import testing +from ... import union + + +class DeprecatedCompoundSelectTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ], + ) + + def _assert_result(self, conn, select, result, params=()): + eq_(conn.execute(select, params).fetchall(), result) + + def test_plain_union(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + # note we've had to remove one use case entirely, which is this + # one. the Select gets its FROMS from the WHERE clause and the + # columns clause, but not the ORDER BY, which means the old ".c" system + # allowed you to "order_by(s.c.foo)" to get an unnamed column in the + # ORDER BY without adding the SELECT into the FROM and breaking the + # query. Users will have to adjust for this use case if they were doing + # it before. + def _dont_test_select_from_plain_union(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2).alias().select() + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.order_by_col_from_union + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_wo_limit_offset + def test_order_by_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_distinct_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).distinct() + s2 = select(table).where(table.c.id == 3).distinct() + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_limit_offset_aliased_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = ( + select(table) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select(table) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py new file mode 100644 index 0000000..c2c17d0 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -0,0 +1,361 @@ +#! coding: utf-8 + +from . import testing +from .. import assert_raises +from .. import config +from .. import engines +from .. import eq_ +from .. import fixtures +from .. import ne_ +from .. import provide_metadata +from ..config import requirements +from ..provision import set_default_schema_on_connection +from ..schema import Column +from ..schema import Table +from ... import bindparam +from ... import event +from ... import exc +from ... import Integer +from ... import literal_column +from ... import select +from ... import String +from ...util import compat + + +class ExceptionTest(fixtures.TablesTest): + """Test basic exception wrapping. + + DBAPIs vary a lot in exception behavior so to actually anticipate + specific exceptions from real round trips, we need to be conservative. + + """ + + run_deletes = "each" + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + + @requirements.duplicate_key_raises_integrity_error + def test_integrity_error(self): + + with config.db.connect() as conn: + + trans = conn.begin() + conn.execute( + self.tables.manual_pk.insert(), {"id": 1, "data": "d1"} + ) + + assert_raises( + exc.IntegrityError, + conn.execute, + self.tables.manual_pk.insert(), + {"id": 1, "data": "d1"}, + ) + + trans.rollback() + + def test_exception_with_non_ascii(self): + with config.db.connect() as conn: + try: + # try to create an error message that likely has non-ascii + # characters in the DBAPI's message string. unfortunately + # there's no way to make this happen with some drivers like + # mysqlclient, pymysql. this at least does produce a non- + # ascii error message for cx_oracle, psycopg2 + conn.execute(select(literal_column(u"méil"))) + assert False + except exc.DBAPIError as err: + err_str = str(err) + + assert str(err.orig) in str(err) + + # test that we are actually getting string on Py2k, unicode + # on Py3k. + if compat.py2k: + assert isinstance(err_str, str) + else: + assert isinstance(err_str, str) + + +class IsolationLevelTest(fixtures.TestBase): + __backend__ = True + + __requires__ = ("isolation_level",) + + def _get_non_default_isolation_level(self): + levels = requirements.get_isolation_levels(config) + + default = levels["default"] + supported = levels["supported"] + + s = set(supported).difference(["AUTOCOMMIT", default]) + if s: + return s.pop() + else: + config.skip_test("no non-default isolation level available") + + def test_default_isolation_level(self): + eq_( + config.db.dialect.default_isolation_level, + requirements.get_isolation_levels(config)["default"], + ) + + def test_non_default_isolation_level(self): + non_default = self._get_non_default_isolation_level() + + with config.db.connect() as conn: + existing = conn.get_isolation_level() + + ne_(existing, non_default) + + conn.execution_options(isolation_level=non_default) + + eq_(conn.get_isolation_level(), non_default) + + conn.dialect.reset_isolation_level(conn.connection) + + eq_(conn.get_isolation_level(), existing) + + def test_all_levels(self): + levels = requirements.get_isolation_levels(config) + + all_levels = levels["supported"] + + for level in set(all_levels).difference(["AUTOCOMMIT"]): + with config.db.connect() as conn: + conn.execution_options(isolation_level=level) + + eq_(conn.get_isolation_level(), level) + + trans = conn.begin() + trans.rollback() + + eq_(conn.get_isolation_level(), level) + + with config.db.connect() as conn: + eq_( + conn.get_isolation_level(), + levels["default"], + ) + + +class AutocommitIsolationTest(fixtures.TablesTest): + + run_deletes = "each" + + __requires__ = ("autocommit",) + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + test_needs_acid=True, + ) + + def _test_conn_autocommits(self, conn, autocommit): + trans = conn.begin() + conn.execute( + self.tables.some_table.insert(), {"id": 1, "data": "some data"} + ) + trans.rollback() + + eq_( + conn.scalar(select(self.tables.some_table.c.id)), + 1 if autocommit else None, + ) + + with conn.begin(): + conn.execute(self.tables.some_table.delete()) + + def test_autocommit_on(self, connection_no_trans): + conn = connection_no_trans + c2 = conn.execution_options(isolation_level="AUTOCOMMIT") + self._test_conn_autocommits(c2, True) + + c2.dialect.reset_isolation_level(c2.connection) + + self._test_conn_autocommits(conn, False) + + def test_autocommit_off(self, connection_no_trans): + conn = connection_no_trans + self._test_conn_autocommits(conn, False) + + def test_turn_autocommit_off_via_default_iso_level( + self, connection_no_trans + ): + conn = connection_no_trans + conn = conn.execution_options(isolation_level="AUTOCOMMIT") + self._test_conn_autocommits(conn, True) + + conn.execution_options( + isolation_level=requirements.get_isolation_levels(config)[ + "default" + ] + ) + self._test_conn_autocommits(conn, False) + + +class EscapingTest(fixtures.TestBase): + @provide_metadata + def test_percent_sign_round_trip(self): + """test that the DBAPI accommodates for escaped / nonescaped + percent signs in a way that matches the compiler + + """ + m = self.metadata + t = Table("t", m, Column("data", String(50))) + t.create(config.db) + with config.db.begin() as conn: + conn.execute(t.insert(), dict(data="some % value")) + conn.execute(t.insert(), dict(data="some %% other value")) + + eq_( + conn.scalar( + select(t.c.data).where( + t.c.data == literal_column("'some % value'") + ) + ), + "some % value", + ) + + eq_( + conn.scalar( + select(t.c.data).where( + t.c.data == literal_column("'some %% other value'") + ) + ), + "some %% other value", + ) + + +class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase): + __backend__ = True + + __requires__ = ("default_schema_name_switch",) + + def test_control_case(self): + default_schema_name = config.db.dialect.default_schema_name + + eng = engines.testing_engine() + with eng.connect(): + pass + + eq_(eng.dialect.default_schema_name, default_schema_name) + + def test_wont_work_wo_insert(self): + default_schema_name = config.db.dialect.default_schema_name + + eng = engines.testing_engine() + + @event.listens_for(eng, "connect") + def on_connect(dbapi_connection, connection_record): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, default_schema_name) + + def test_schema_change_on_connect(self): + eng = engines.testing_engine() + + @event.listens_for(eng, "connect", insert=True) + def on_connect(dbapi_connection, connection_record): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, config.test_schema) + + def test_schema_change_works_w_transactions(self): + eng = engines.testing_engine() + + @event.listens_for(eng, "connect", insert=True) + def on_connect(dbapi_connection, *arg): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + trans = conn.begin() + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + trans.rollback() + + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, config.test_schema) + + +class FutureWeCanSetDefaultSchemaWEventsTest( + fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest +): + pass + + +class DifficultParametersTest(fixtures.TestBase): + __backend__ = True + + @testing.combinations( + ("boring",), + ("per cent",), + ("per % cent",), + ("%percent",), + ("par(ens)",), + ("percent%(ens)yah",), + ("col:ons",), + ("more :: %colons%",), + ("/slashes/",), + ("more/slashes",), + ("q?marks",), + ("1param",), + ("1col:on",), + argnames="name", + ) + def test_round_trip(self, name, connection, metadata): + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column(name, String(50), nullable=False), + ) + + # table is created + t.create(connection) + + # automatic param generated by insert + connection.execute(t.insert().values({"id": 1, name: "some name"})) + + # automatic param generated by criteria, plus selecting the column + stmt = select(t.c[name]).where(t.c[name] == "some name") + + eq_(connection.scalar(stmt), "some name") + + # use the name in a param explicitly + stmt = select(t.c[name]).where(t.c[name] == bindparam(name)) + + row = connection.execute(stmt, {name: "some name"}).first() + + # name works as the key from cursor.description + eq_(row._mapping[name], "some name") diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py new file mode 100644 index 0000000..3c22f50 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -0,0 +1,367 @@ +from .. import config +from .. import engines +from .. import fixtures +from ..assertions import eq_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import Integer +from ... import literal +from ... import literal_column +from ... import select +from ... import String + + +class LastrowidTest(fixtures.TablesTest): + run_deletes = "each" + + __backend__ = True + + __requires__ = "implements_get_lastrowid", "autoincrement_insert" + + __engine_options__ = {"implicit_returning": False} + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_( + row, + ( + conn.dialect.default_sequence_base, + "some data", + ), + ) + + def test_autoincrement_on_insert(self, connection): + + connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.autoinc_pk, connection) + + def test_last_inserted_id(self, connection): + + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(r.inserted_primary_key, (pk,)) + + @requirements.dbapi_lastrowid + def test_native_lastrowid_autoinc(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + lastrowid = r.lastrowid + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(lastrowid, pk) + + +class InsertBehaviorTest(fixtures.TablesTest): + run_deletes = "each" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + Table( + "includes_defaults", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("x", Integer, default=5), + Column( + "y", + Integer, + default=literal_column("2", type_=Integer) + literal(2), + ), + ) + + @requirements.autoincrement_insert + def test_autoclose_on_insert(self): + if requirements.returning.enabled: + engine = engines.testing_engine( + options={"implicit_returning": False} + ) + else: + engine = config.db + + with engine.begin() as conn: + r = conn.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + assert r._soft_closed + assert not r.closed + assert r.is_insert + + # new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment + # an insert where the PK was taken from a row that the dialect + # selected, as is the case for mssql/pyodbc, will still report + # returns_rows as true because there's a cursor description. in that + # case, the row had to have been consumed at least. + assert not r.returns_rows or r.fetchone() is None + + @requirements.returning + def test_autoclose_on_insert_implicit_returning(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + assert r._soft_closed + assert not r.closed + assert r.is_insert + + # note we are experimenting with having this be True + # as of I8091919d45421e3f53029b8660427f844fee0228 . + # implicit returning has fetched the row, but it still is a + # "returns rows" + assert r.returns_rows + + # and we should be able to fetchone() on it, we just get no row + eq_(r.fetchone(), None) + + # and the keys, etc. + eq_(r.keys(), ["id"]) + + # but the dialect took in the row already. not really sure + # what the best behavior is. + + @requirements.empty_inserts + def test_empty_insert(self, connection): + r = connection.execute(self.tables.autoinc_pk.insert()) + assert r._soft_closed + assert not r.closed + + r = connection.execute( + self.tables.autoinc_pk.select().where( + self.tables.autoinc_pk.c.id != None + ) + ) + eq_(len(r.all()), 1) + + @requirements.empty_inserts_executemany + def test_empty_insert_multiple(self, connection): + r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}]) + assert r._soft_closed + assert not r.closed + + r = connection.execute( + self.tables.autoinc_pk.select().where( + self.tables.autoinc_pk.c.id != None + ) + ) + + eq_(len(r.all()), 3) + + @requirements.insert_from_select + def test_insert_from_select_autoinc(self, connection): + src_table = self.tables.manual_pk + dest_table = self.tables.autoinc_pk + connection.execute( + src_table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + result = connection.execute( + dest_table.insert().from_select( + ("data",), + select(src_table.c.data).where( + src_table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_(result.inserted_primary_key, (None,)) + + result = connection.execute( + select(dest_table.c.data).order_by(dest_table.c.data) + ) + eq_(result.fetchall(), [("data2",), ("data3",)]) + + @requirements.insert_from_select + def test_insert_from_select_autoinc_no_rows(self, connection): + src_table = self.tables.manual_pk + dest_table = self.tables.autoinc_pk + + result = connection.execute( + dest_table.insert().from_select( + ("data",), + select(src_table.c.data).where( + src_table.c.data.in_(["data2", "data3"]) + ), + ) + ) + eq_(result.inserted_primary_key, (None,)) + + result = connection.execute( + select(dest_table.c.data).order_by(dest_table.c.data) + ) + + eq_(result.fetchall(), []) + + @requirements.insert_from_select + def test_insert_from_select(self, connection): + table = self.tables.manual_pk + connection.execute( + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + connection.execute( + table.insert() + .inline() + .from_select( + ("id", "data"), + select(table.c.id + 5, table.c.data).where( + table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_( + connection.execute( + select(table.c.data).order_by(table.c.data) + ).fetchall(), + [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)], + ) + + @requirements.insert_from_select + def test_insert_from_select_with_defaults(self, connection): + table = self.tables.includes_defaults + connection.execute( + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + connection.execute( + table.insert() + .inline() + .from_select( + ("id", "data"), + select(table.c.id + 5, table.c.data).where( + table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_( + connection.execute( + select(table).order_by(table.c.data, table.c.id) + ).fetchall(), + [ + (1, "data1", 5, 4), + (2, "data2", 5, 4), + (7, "data2", 5, 4), + (3, "data3", 5, 4), + (8, "data3", 5, 4), + ], + ) + + +class ReturningTest(fixtures.TablesTest): + run_create_tables = "each" + __requires__ = "returning", "autoincrement_insert" + __backend__ = True + + __engine_options__ = {"implicit_returning": True} + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_( + row, + ( + conn.dialect.default_sequence_base, + "some data", + ), + ) + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + + @requirements.fetch_rows_post_commit + def test_explicit_returning_pk_autocommit(self, connection): + table = self.tables.autoinc_pk + r = connection.execute( + table.insert().returning(table.c.id), dict(data="some data") + ) + pk = r.first()[0] + fetched_pk = connection.scalar(select(table.c.id)) + eq_(fetched_pk, pk) + + def test_explicit_returning_pk_no_autocommit(self, connection): + table = self.tables.autoinc_pk + r = connection.execute( + table.insert().returning(table.c.id), dict(data="some data") + ) + pk = r.first()[0] + fetched_pk = connection.scalar(select(table.c.id)) + eq_(fetched_pk, pk) + + def test_autoincrement_on_insert_implicit_returning(self, connection): + + connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.autoinc_pk, connection) + + def test_last_inserted_id_implicit_returning(self, connection): + + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(r.inserted_primary_key, (pk,)) + + +__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest") diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py new file mode 100644 index 0000000..459a4d8 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -0,0 +1,1738 @@ +import operator +import re + +import sqlalchemy as sa +from .. import config +from .. import engines +from .. import eq_ +from .. import expect_warnings +from .. import fixtures +from .. import is_ +from ..provision import get_temp_table_name +from ..provision import temp_table_keyword_args +from ..schema import Column +from ..schema import Table +from ... import event +from ... import ForeignKey +from ... import func +from ... import Identity +from ... import inspect +from ... import Integer +from ... import MetaData +from ... import String +from ... import testing +from ... import types as sql_types +from ...schema import DDL +from ...schema import Index +from ...sql.elements import quoted_name +from ...sql.schema import BLANK_SCHEMA +from ...testing import is_false +from ...testing import is_true + + +metadata, users = None, None + + +class HasTableTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + if testing.requires.schemas.enabled: + Table( + "test_table_s", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + schema=config.test_schema, + ) + + def test_has_table(self): + with config.db.begin() as conn: + is_true(config.db.dialect.has_table(conn, "test_table")) + is_false(config.db.dialect.has_table(conn, "test_table_s")) + is_false(config.db.dialect.has_table(conn, "nonexistent_table")) + + @testing.requires.schemas + def test_has_table_schema(self): + with config.db.begin() as conn: + is_false( + config.db.dialect.has_table( + conn, "test_table", schema=config.test_schema + ) + ) + is_true( + config.db.dialect.has_table( + conn, "test_table_s", schema=config.test_schema + ) + ) + is_false( + config.db.dialect.has_table( + conn, "nonexistent_table", schema=config.test_schema + ) + ) + + +class HasIndexTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + tt = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + Index("my_idx", tt.c.data) + + if testing.requires.schemas.enabled: + tt = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + schema=config.test_schema, + ) + Index("my_idx_s", tt.c.data) + + def test_has_index(self): + with config.db.begin() as conn: + assert config.db.dialect.has_index(conn, "test_table", "my_idx") + assert not config.db.dialect.has_index( + conn, "test_table", "my_idx_s" + ) + assert not config.db.dialect.has_index( + conn, "nonexistent_table", "my_idx" + ) + assert not config.db.dialect.has_index( + conn, "test_table", "nonexistent_idx" + ) + + @testing.requires.schemas + def test_has_index_schema(self): + with config.db.begin() as conn: + assert config.db.dialect.has_index( + conn, "test_table", "my_idx_s", schema=config.test_schema + ) + assert not config.db.dialect.has_index( + conn, "test_table", "my_idx", schema=config.test_schema + ) + assert not config.db.dialect.has_index( + conn, + "nonexistent_table", + "my_idx_s", + schema=config.test_schema, + ) + assert not config.db.dialect.has_index( + conn, + "test_table", + "nonexistent_idx_s", + schema=config.test_schema, + ) + + +class QuotedNameArgumentTest(fixtures.TablesTest): + run_create_tables = "once" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "quote ' one", + metadata, + Column("id", Integer), + Column("name", String(50)), + Column("data", String(50)), + Column("related_id", Integer), + sa.PrimaryKeyConstraint("id", name="pk quote ' one"), + sa.Index("ix quote ' one", "name"), + sa.UniqueConstraint( + "data", + name="uq quote' one", + ), + sa.ForeignKeyConstraint( + ["id"], ["related.id"], name="fk quote ' one" + ), + sa.CheckConstraint("name != 'foo'", name="ck quote ' one"), + comment=r"""quote ' one comment""", + test_needs_fk=True, + ) + + if testing.requires.symbol_names_w_double_quote.enabled: + Table( + 'quote " two', + metadata, + Column("id", Integer), + Column("name", String(50)), + Column("data", String(50)), + Column("related_id", Integer), + sa.PrimaryKeyConstraint("id", name='pk quote " two'), + sa.Index('ix quote " two', "name"), + sa.UniqueConstraint( + "data", + name='uq quote" two', + ), + sa.ForeignKeyConstraint( + ["id"], ["related.id"], name='fk quote " two' + ), + sa.CheckConstraint("name != 'foo'", name='ck quote " two '), + comment=r"""quote " two comment""", + test_needs_fk=True, + ) + + Table( + "related", + metadata, + Column("id", Integer, primary_key=True), + Column("related", Integer), + test_needs_fk=True, + ) + + if testing.requires.view_column_reflection.enabled: + + if testing.requires.symbol_names_w_double_quote.enabled: + names = [ + "quote ' one", + 'quote " two', + ] + else: + names = [ + "quote ' one", + ] + for name in names: + query = "CREATE VIEW %s AS SELECT * FROM %s" % ( + config.db.dialect.identifier_preparer.quote( + "view %s" % name + ), + config.db.dialect.identifier_preparer.quote(name), + ) + + event.listen(metadata, "after_create", DDL(query)) + event.listen( + metadata, + "before_drop", + DDL( + "DROP VIEW %s" + % config.db.dialect.identifier_preparer.quote( + "view %s" % name + ) + ), + ) + + def quote_fixtures(fn): + return testing.combinations( + ("quote ' one",), + ('quote " two', testing.requires.symbol_names_w_double_quote), + )(fn) + + @quote_fixtures + def test_get_table_options(self, name): + insp = inspect(config.db) + + insp.get_table_options(name) + + @quote_fixtures + @testing.requires.view_column_reflection + def test_get_view_definition(self, name): + insp = inspect(config.db) + assert insp.get_view_definition("view %s" % name) + + @quote_fixtures + def test_get_columns(self, name): + insp = inspect(config.db) + assert insp.get_columns(name) + + @quote_fixtures + def test_get_pk_constraint(self, name): + insp = inspect(config.db) + assert insp.get_pk_constraint(name) + + @quote_fixtures + def test_get_foreign_keys(self, name): + insp = inspect(config.db) + assert insp.get_foreign_keys(name) + + @quote_fixtures + def test_get_indexes(self, name): + insp = inspect(config.db) + assert insp.get_indexes(name) + + @quote_fixtures + @testing.requires.unique_constraint_reflection + def test_get_unique_constraints(self, name): + insp = inspect(config.db) + assert insp.get_unique_constraints(name) + + @quote_fixtures + @testing.requires.comment_reflection + def test_get_table_comment(self, name): + insp = inspect(config.db) + assert insp.get_table_comment(name) + + @quote_fixtures + @testing.requires.check_constraint_reflection + def test_get_check_constraints(self, name): + insp = inspect(config.db) + assert insp.get_check_constraints(name) + + +class ComponentReflectionTest(fixtures.TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + + @classmethod + def setup_bind(cls): + if config.requirements.independent_connections.enabled: + from sqlalchemy import pool + + return engines.testing_engine( + options=dict(poolclass=pool.StaticPool, scope="class"), + ) + else: + return config.db + + @classmethod + def define_tables(cls, metadata): + cls.define_reflected_tables(metadata, None) + if testing.requires.schemas.enabled: + cls.define_reflected_tables(metadata, testing.config.test_schema) + + @classmethod + def define_reflected_tables(cls, metadata, schema): + if schema: + schema_prefix = schema + "." + else: + schema_prefix = "" + + if testing.requires.self_referential_foreign_keys.enabled: + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + Column( + "parent_user_id", + sa.Integer, + sa.ForeignKey( + "%susers.user_id" % schema_prefix, name="user_id_fk" + ), + ), + schema=schema, + test_needs_fk=True, + ) + else: + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + schema=schema, + test_needs_fk=True, + ) + + Table( + "dingalings", + metadata, + Column("dingaling_id", sa.Integer, primary_key=True), + Column( + "address_id", + sa.Integer, + sa.ForeignKey("%semail_addresses.address_id" % schema_prefix), + ), + Column("data", sa.String(30)), + schema=schema, + test_needs_fk=True, + ) + Table( + "email_addresses", + metadata, + Column("address_id", sa.Integer), + Column( + "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id) + ), + Column("email_address", sa.String(20)), + sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"), + schema=schema, + test_needs_fk=True, + ) + Table( + "comment_test", + metadata, + Column("id", sa.Integer, primary_key=True, comment="id comment"), + Column("data", sa.String(20), comment="data % comment"), + Column( + "d2", + sa.String(20), + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + schema=schema, + comment=r"""the test % ' " \ table comment""", + ) + + if testing.requires.cross_schema_fk_reflection.enabled: + if schema is None: + Table( + "local_table", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), + Column( + "remote_id", + ForeignKey( + "%s.remote_table_2.id" % testing.config.test_schema + ), + ), + test_needs_fk=True, + schema=config.db.dialect.default_schema_name, + ) + else: + Table( + "remote_table", + metadata, + Column("id", sa.Integer, primary_key=True), + Column( + "local_id", + ForeignKey( + "%s.local_table.id" + % config.db.dialect.default_schema_name + ), + ), + Column("data", sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + Table( + "remote_table_2", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + + if testing.requires.index_reflection.enabled: + cls.define_index(metadata, users) + + if not schema: + # test_needs_fk is at the moment to force MySQL InnoDB + noncol_idx_test_nopk = Table( + "noncol_idx_test_nopk", + metadata, + Column("q", sa.String(5)), + test_needs_fk=True, + ) + + noncol_idx_test_pk = Table( + "noncol_idx_test_pk", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("q", sa.String(5)), + test_needs_fk=True, + ) + + if testing.requires.indexes_with_ascdesc.enabled: + Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) + Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) + + if testing.requires.view_column_reflection.enabled: + cls.define_views(metadata, schema) + if not schema and testing.requires.temp_table_reflection.enabled: + cls.define_temp_tables(metadata) + + @classmethod + def define_temp_tables(cls, metadata): + kw = temp_table_keyword_args(config, config.db) + table_name = get_temp_table_name( + config, config.db, "user_tmp_%s" % config.ident + ) + user_tmp = Table( + table_name, + metadata, + Column("id", sa.INT, primary_key=True), + Column("name", sa.VARCHAR(50)), + Column("foo", sa.INT), + # disambiguate temp table unique constraint names. this is + # pretty arbitrary for a generic dialect however we are doing + # it to suit SQL Server which will produce name conflicts for + # unique constraints created against temp tables in different + # databases. + # https://www.arbinada.com/en/node/1645 + sa.UniqueConstraint("name", name="user_tmp_uq_%s" % config.ident), + sa.Index("user_tmp_ix", "foo"), + **kw + ) + if ( + testing.requires.view_reflection.enabled + and testing.requires.temporary_views.enabled + ): + event.listen( + user_tmp, + "after_create", + DDL( + "create temporary view user_tmp_v as " + "select * from user_tmp_%s" % config.ident + ), + ) + event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) + + @classmethod + def define_index(cls, metadata, users): + Index("users_t_idx", users.c.test1, users.c.test2) + Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1) + + @classmethod + def define_views(cls, metadata, schema): + for table_name in ("users", "email_addresses"): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + "_v" + query = "CREATE VIEW %s AS SELECT * FROM %s" % ( + view_name, + fullname, + ) + + event.listen(metadata, "after_create", DDL(query)) + event.listen( + metadata, "before_drop", DDL("DROP VIEW %s" % view_name) + ) + + @testing.requires.schema_reflection + def test_get_schema_names(self): + insp = inspect(self.bind) + + self.assert_(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_get_schema_names_w_translate_map(self, connection): + """test #7300""" + + connection = connection.execution_options( + schema_translate_map={ + "foo": "bar", + BLANK_SCHEMA: testing.config.test_schema, + } + ) + insp = inspect(connection) + + self.assert_(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_dialect_initialize(self): + engine = engines.testing_engine() + inspect(engine) + assert hasattr(engine.dialect, "default_schema_name") + + @testing.requires.schema_reflection + def test_get_default_schema_name(self): + insp = inspect(self.bind) + eq_(insp.default_schema_name, self.bind.dialect.default_schema_name) + + @testing.requires.foreign_key_constraint_reflection + @testing.combinations( + (None, True, False, False), + (None, True, False, True, testing.requires.schemas), + ("foreign_key", True, False, False), + (None, False, True, False), + (None, False, True, True, testing.requires.schemas), + (None, True, True, False), + (None, True, True, True, testing.requires.schemas), + argnames="order_by,include_plain,include_views,use_schema", + ) + def test_get_table_names( + self, connection, order_by, include_plain, include_views, use_schema + ): + + if use_schema: + schema = config.test_schema + else: + schema = None + + _ignore_tables = [ + "comment_test", + "noncol_idx_test_pk", + "noncol_idx_test_nopk", + "local_table", + "remote_table", + "remote_table_2", + ] + + insp = inspect(connection) + + if include_views: + table_names = insp.get_view_names(schema) + table_names.sort() + answer = ["email_addresses_v", "users_v"] + eq_(sorted(table_names), answer) + + if include_plain: + if order_by: + tables = [ + rec[0] + for rec in insp.get_sorted_table_and_fkc_names(schema) + if rec[0] + ] + else: + tables = insp.get_table_names(schema) + table_names = [t for t in tables if t not in _ignore_tables] + + if order_by == "foreign_key": + answer = ["users", "email_addresses", "dingalings"] + eq_(table_names, answer) + else: + answer = ["dingalings", "email_addresses", "users"] + eq_(sorted(table_names), answer) + + @testing.requires.temp_table_names + def test_get_temp_table_names(self): + insp = inspect(self.bind) + temp_table_names = insp.get_temp_table_names() + eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident]) + + @testing.requires.view_reflection + @testing.requires.temp_table_names + @testing.requires.temporary_views + def test_get_temp_view_names(self): + insp = inspect(self.bind) + temp_table_names = insp.get_temp_view_names() + eq_(sorted(temp_table_names), ["user_tmp_v"]) + + @testing.requires.comment_reflection + def test_get_comments(self): + self._test_get_comments() + + @testing.requires.comment_reflection + @testing.requires.schemas + def test_get_comments_with_schema(self): + self._test_get_comments(testing.config.test_schema) + + def _test_get_comments(self, schema=None): + insp = inspect(self.bind) + + eq_( + insp.get_table_comment("comment_test", schema=schema), + {"text": r"""the test % ' " \ table comment"""}, + ) + + eq_(insp.get_table_comment("users", schema=schema), {"text": None}) + + eq_( + [ + {"name": rec["name"], "comment": rec["comment"]} + for rec in insp.get_columns("comment_test", schema=schema) + ], + [ + {"comment": "id comment", "name": "id"}, + {"comment": "data % comment", "name": "data"}, + { + "comment": ( + r"""Comment types type speedily ' " \ '' Fun!""" + ), + "name": "d2", + }, + ], + ) + + @testing.combinations( + (False, False), + (False, True, testing.requires.schemas), + (True, False, testing.requires.view_reflection), + ( + True, + True, + testing.requires.schemas + testing.requires.view_reflection, + ), + argnames="use_views,use_schema", + ) + def test_get_columns(self, connection, use_views, use_schema): + + if use_schema: + schema = config.test_schema + else: + schema = None + + users, addresses = (self.tables.users, self.tables.email_addresses) + if use_views: + table_names = ["users_v", "email_addresses_v"] + else: + table_names = ["users", "email_addresses"] + + insp = inspect(connection) + for table_name, table in zip(table_names, (users, addresses)): + schema_name = schema + cols = insp.get_columns(table_name, schema=schema_name) + self.assert_(len(cols) > 0, len(cols)) + + # should be in order + + for i, col in enumerate(table.columns): + eq_(col.name, cols[i]["name"]) + ctype = cols[i]["type"].__class__ + ctype_def = col.type + if isinstance(ctype_def, sa.types.TypeEngine): + ctype_def = ctype_def.__class__ + + # Oracle returns Date for DateTime. + + if testing.against("oracle") and ctype_def in ( + sql_types.Date, + sql_types.DateTime, + ): + ctype_def = sql_types.Date + + # assert that the desired type and return type share + # a base within one of the generic types. + + self.assert_( + len( + set(ctype.__mro__) + .intersection(ctype_def.__mro__) + .intersection( + [ + sql_types.Integer, + sql_types.Numeric, + sql_types.DateTime, + sql_types.Date, + sql_types.Time, + sql_types.String, + sql_types._Binary, + ] + ) + ) + > 0, + "%s(%s), %s(%s)" + % (col.name, col.type, cols[i]["name"], ctype), + ) + + if not col.primary_key: + assert cols[i]["default"] is None + + @testing.requires.temp_table_reflection + def test_get_temp_table_columns(self): + table_name = get_temp_table_name( + config, self.bind, "user_tmp_%s" % config.ident + ) + user_tmp = self.tables[table_name] + insp = inspect(self.bind) + cols = insp.get_columns(table_name) + self.assert_(len(cols) > 0, len(cols)) + + for i, col in enumerate(user_tmp.columns): + eq_(col.name, cols[i]["name"]) + + @testing.requires.temp_table_reflection + @testing.requires.view_column_reflection + @testing.requires.temporary_views + def test_get_temp_view_columns(self): + insp = inspect(self.bind) + cols = insp.get_columns("user_tmp_v") + eq_([col["name"] for col in cols], ["id", "name", "foo"]) + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.primary_key_constraint_reflection + def test_get_pk_constraint(self, connection, use_schema): + if use_schema: + schema = testing.config.test_schema + else: + schema = None + + users, addresses = self.tables.users, self.tables.email_addresses + insp = inspect(connection) + + users_cons = insp.get_pk_constraint(users.name, schema=schema) + users_pkeys = users_cons["constrained_columns"] + eq_(users_pkeys, ["user_id"]) + + addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) + addr_pkeys = addr_cons["constrained_columns"] + eq_(addr_pkeys, ["address_id"]) + + with testing.requires.reflects_pk_names.fail_if(): + eq_(addr_cons["name"], "email_ad_pk") + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.foreign_key_constraint_reflection + def test_get_foreign_keys(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + users, addresses = (self.tables.users, self.tables.email_addresses) + insp = inspect(connection) + expected_schema = schema + # users + + if testing.requires.self_referential_foreign_keys.enabled: + users_fkeys = insp.get_foreign_keys(users.name, schema=schema) + fkey1 = users_fkeys[0] + + with testing.requires.named_constraints.fail_if(): + eq_(fkey1["name"], "user_id_fk") + + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + if testing.requires.self_referential_foreign_keys.enabled: + eq_(fkey1["constrained_columns"], ["parent_user_id"]) + + # addresses + addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) + fkey1 = addr_fkeys[0] + + with testing.requires.implicitly_named_constraints.fail_if(): + self.assert_(fkey1["name"] is not None) + + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + eq_(fkey1["constrained_columns"], ["remote_user_id"]) + + @testing.requires.cross_schema_fk_reflection + @testing.requires.schemas + def test_get_inter_schema_foreign_keys(self): + local_table, remote_table, remote_table_2 = self.tables( + "%s.local_table" % self.bind.dialect.default_schema_name, + "%s.remote_table" % testing.config.test_schema, + "%s.remote_table_2" % testing.config.test_schema, + ) + + insp = inspect(self.bind) + + local_fkeys = insp.get_foreign_keys(local_table.name) + eq_(len(local_fkeys), 1) + + fkey1 = local_fkeys[0] + eq_(fkey1["referred_schema"], testing.config.test_schema) + eq_(fkey1["referred_table"], remote_table_2.name) + eq_(fkey1["referred_columns"], ["id"]) + eq_(fkey1["constrained_columns"], ["remote_id"]) + + remote_fkeys = insp.get_foreign_keys( + remote_table.name, schema=testing.config.test_schema + ) + eq_(len(remote_fkeys), 1) + + fkey2 = remote_fkeys[0] + + assert fkey2["referred_schema"] in ( + None, + self.bind.dialect.default_schema_name, + ) + eq_(fkey2["referred_table"], local_table.name) + eq_(fkey2["referred_columns"], ["id"]) + eq_(fkey2["constrained_columns"], ["local_id"]) + + def _assert_insp_indexes(self, indexes, expected_indexes): + index_names = [d["name"] for d in indexes] + for e_index in expected_indexes: + assert e_index["name"] in index_names + index = indexes[index_names.index(e_index["name"])] + for key in e_index: + eq_(e_index[key], index[key]) + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_indexes(self, connection, use_schema): + + if use_schema: + schema = config.test_schema + else: + schema = None + + # The database may decide to create indexes for foreign keys, etc. + # so there may be more indexes than expected. + insp = inspect(self.bind) + indexes = insp.get_indexes("users", schema=schema) + expected_indexes = [ + { + "unique": False, + "column_names": ["test1", "test2"], + "name": "users_t_idx", + }, + { + "unique": False, + "column_names": ["user_id", "test2", "test1"], + "name": "users_all_idx", + }, + ] + self._assert_insp_indexes(indexes, expected_indexes) + + @testing.combinations( + ("noncol_idx_test_nopk", "noncol_idx_nopk"), + ("noncol_idx_test_pk", "noncol_idx_pk"), + argnames="tname,ixname", + ) + @testing.requires.index_reflection + @testing.requires.indexes_with_ascdesc + def test_get_noncol_index(self, connection, tname, ixname): + insp = inspect(connection) + indexes = insp.get_indexes(tname) + + # reflecting an index that has "x DESC" in it as the column. + # the DB may or may not give us "x", but make sure we get the index + # back, it has a name, it's connected to the table. + expected_indexes = [{"unique": False, "name": ixname}] + self._assert_insp_indexes(indexes, expected_indexes) + + t = Table(tname, MetaData(), autoload_with=connection) + eq_(len(t.indexes), 1) + is_(list(t.indexes)[0].table, t) + eq_(list(t.indexes)[0].name, ixname) + + @testing.requires.temp_table_reflection + @testing.requires.unique_constraint_reflection + def test_get_temp_table_unique_constraints(self): + insp = inspect(self.bind) + reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident) + for refl in reflected: + # Different dialects handle duplicate index and constraints + # differently, so ignore this flag + refl.pop("duplicates_index", None) + eq_( + reflected, + [ + { + "column_names": ["name"], + "name": "user_tmp_uq_%s" % config.ident, + } + ], + ) + + @testing.requires.temp_table_reflect_indexes + def test_get_temp_table_indexes(self): + insp = inspect(self.bind) + table_name = get_temp_table_name( + config, config.db, "user_tmp_%s" % config.ident + ) + indexes = insp.get_indexes(table_name) + for ind in indexes: + ind.pop("dialect_options", None) + expected = [ + {"unique": False, "column_names": ["foo"], "name": "user_tmp_ix"} + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + eq_( + [idx for idx in indexes if idx["name"] == "user_tmp_ix"], + expected, + ) + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + @testing.requires.unique_constraint_reflection + def test_get_unique_constraints(self, metadata, connection, use_schema): + # SQLite dialect needs to parse the names of the constraints + # separately from what it gets from PRAGMA index_list(), and + # then matches them up. so same set of column_names in two + # constraints will confuse it. Perhaps we should no longer + # bother with index_list() here since we have the whole + # CREATE TABLE? + + if use_schema: + schema = config.test_schema + else: + schema = None + uniques = sorted( + [ + {"name": "unique_a", "column_names": ["a"]}, + {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]}, + {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]}, + {"name": "unique_asc_key", "column_names": ["asc", "key"]}, + {"name": "i.have.dots", "column_names": ["b"]}, + {"name": "i have spaces", "column_names": ["c"]}, + ], + key=operator.itemgetter("name"), + ) + table = Table( + "testtbl", + metadata, + Column("a", sa.String(20)), + Column("b", sa.String(30)), + Column("c", sa.Integer), + # reserved identifiers + Column("asc", sa.String(30)), + Column("key", sa.String(30)), + schema=schema, + ) + for uc in uniques: + table.append_constraint( + sa.UniqueConstraint(*uc["column_names"], name=uc["name"]) + ) + table.create(connection) + + inspector = inspect(connection) + reflected = sorted( + inspector.get_unique_constraints("testtbl", schema=schema), + key=operator.itemgetter("name"), + ) + + names_that_duplicate_index = set() + + for orig, refl in zip(uniques, reflected): + # Different dialects handle duplicate index and constraints + # differently, so ignore this flag + dupe = refl.pop("duplicates_index", None) + if dupe: + names_that_duplicate_index.add(dupe) + eq_(orig, refl) + + reflected_metadata = MetaData() + reflected = Table( + "testtbl", + reflected_metadata, + autoload_with=connection, + schema=schema, + ) + + # test "deduplicates for index" logic. MySQL and Oracle + # "unique constraints" are actually unique indexes (with possible + # exception of a unique that is a dupe of another one in the case + # of Oracle). make sure # they aren't duplicated. + idx_names = set([idx.name for idx in reflected.indexes]) + uq_names = set( + [ + uq.name + for uq in reflected.constraints + if isinstance(uq, sa.UniqueConstraint) + ] + ).difference(["unique_c_a_b"]) + + assert not idx_names.intersection(uq_names) + if names_that_duplicate_index: + eq_(names_that_duplicate_index, idx_names) + eq_(uq_names, set()) + + @testing.requires.view_reflection + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_view_definition(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + view_name1 = "users_v" + view_name2 = "email_addresses_v" + insp = inspect(connection) + v1 = insp.get_view_definition(view_name1, schema=schema) + self.assert_(v1) + v2 = insp.get_view_definition(view_name2, schema=schema) + self.assert_(v2) + + # why is this here if it's PG specific ? + @testing.combinations( + ("users", False), + ("users", True, testing.requires.schemas), + argnames="table_name,use_schema", + ) + @testing.only_on("postgresql", "PG specific feature") + def test_get_table_oid(self, connection, table_name, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + insp = inspect(connection) + oid = insp.get_table_oid(table_name, schema) + self.assert_(isinstance(oid, int)) + + @testing.requires.table_reflection + def test_autoincrement_col(self): + """test that 'autoincrement' is reflected according to sqla's policy. + + Don't mark this test as unsupported for any backend ! + + (technically it fails with MySQL InnoDB since "id" comes before "id2") + + A backend is better off not returning "autoincrement" at all, + instead of potentially returning "False" for an auto-incrementing + primary key column. + + """ + + insp = inspect(self.bind) + + for tname, cname in [ + ("users", "user_id"), + ("email_addresses", "address_id"), + ("dingalings", "dingaling_id"), + ]: + cols = insp.get_columns(tname) + id_ = {c["name"]: c for c in cols}[cname] + assert id_.get("autoincrement", True) + + +class TableNoColumnsTest(fixtures.TestBase): + __requires__ = ("reflect_tables_no_columns",) + __backend__ = True + + @testing.fixture + def table_no_columns(self, connection, metadata): + Table("empty", metadata) + metadata.create_all(connection) + + @testing.fixture + def view_no_columns(self, connection, metadata): + Table("empty", metadata) + metadata.create_all(connection) + + Table("empty", metadata) + event.listen( + metadata, + "after_create", + DDL("CREATE VIEW empty_v AS SELECT * FROM empty"), + ) + + # for transactional DDL the transaction is rolled back before this + # drop statement is invoked + event.listen( + metadata, "before_drop", DDL("DROP VIEW IF EXISTS empty_v") + ) + metadata.create_all(connection) + + @testing.requires.reflect_tables_no_columns + def test_reflect_table_no_columns(self, connection, table_no_columns): + t2 = Table("empty", MetaData(), autoload_with=connection) + eq_(list(t2.c), []) + + @testing.requires.reflect_tables_no_columns + def test_get_columns_table_no_columns(self, connection, table_no_columns): + eq_(inspect(connection).get_columns("empty"), []) + + @testing.requires.reflect_tables_no_columns + def test_reflect_incl_table_no_columns(self, connection, table_no_columns): + m = MetaData() + m.reflect(connection) + assert set(m.tables).intersection(["empty"]) + + @testing.requires.views + @testing.requires.reflect_tables_no_columns + def test_reflect_view_no_columns(self, connection, view_no_columns): + t2 = Table("empty_v", MetaData(), autoload_with=connection) + eq_(list(t2.c), []) + + @testing.requires.views + @testing.requires.reflect_tables_no_columns + def test_get_columns_view_no_columns(self, connection, view_no_columns): + eq_(inspect(connection).get_columns("empty_v"), []) + + +class ComponentReflectionTestExtra(fixtures.TestBase): + + __backend__ = True + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + @testing.requires.check_constraint_reflection + def test_get_check_constraints(self, metadata, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + Table( + "sa_cc", + metadata, + Column("a", Integer()), + sa.CheckConstraint("a > 1 AND a < 5", name="cc1"), + sa.CheckConstraint( + "a = 1 OR (a > 2 AND a < 5)", name="UsesCasing" + ), + schema=schema, + ) + + metadata.create_all(connection) + + inspector = inspect(connection) + reflected = sorted( + inspector.get_check_constraints("sa_cc", schema=schema), + key=operator.itemgetter("name"), + ) + + # trying to minimize effect of quoting, parenthesis, etc. + # may need to add more to this as new dialects get CHECK + # constraint reflection support + def normalize(sqltext): + return " ".join( + re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I) + ) + + reflected = [ + {"name": item["name"], "sqltext": normalize(item["sqltext"])} + for item in reflected + ] + eq_( + reflected, + [ + {"name": "UsesCasing", "sqltext": "a = 1 or a > 2 and a < 5"}, + {"name": "cc1", "sqltext": "a > 1 and a < 5"}, + ], + ) + + @testing.requires.indexes_with_expressions + def test_reflect_expression_based_indexes(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) + + Index("t_idx", func.lower(t.c.x), func.lower(t.c.y)) + + Index("t_idx_2", t.c.x) + + metadata.create_all(connection) + + insp = inspect(connection) + + expected = [ + {"name": "t_idx_2", "column_names": ["x"], "unique": False} + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + expected[0]["dialect_options"] = { + "%s_include" % connection.engine.name: [] + } + + with expect_warnings( + "Skipped unsupported reflection of expression-based index t_idx" + ): + eq_( + insp.get_indexes("t"), + expected, + ) + + @testing.requires.index_reflects_included_columns + def test_reflect_covering_index(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) + idx = Index("t_idx", t.c.x) + idx.dialect_options[connection.engine.name]["include"] = ["y"] + + metadata.create_all(connection) + + insp = inspect(connection) + + eq_( + insp.get_indexes("t"), + [ + { + "name": "t_idx", + "column_names": ["x"], + "include_columns": ["y"], + "unique": False, + "dialect_options": { + "%s_include" % connection.engine.name: ["y"] + }, + } + ], + ) + + t2 = Table("t", MetaData(), autoload_with=connection) + eq_( + list(t2.indexes)[0].dialect_options[connection.engine.name][ + "include" + ], + ["y"], + ) + + def _type_round_trip(self, connection, metadata, *types): + t = Table( + "t", + metadata, + *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] + ) + t.create(connection) + + return [c["type"] for c in inspect(connection).get_columns("t")] + + @testing.requires.table_reflection + def test_numeric_reflection(self, connection, metadata): + for typ in self._type_round_trip( + connection, metadata, sql_types.Numeric(18, 5) + ): + assert isinstance(typ, sql_types.Numeric) + eq_(typ.precision, 18) + eq_(typ.scale, 5) + + @testing.requires.table_reflection + def test_varchar_reflection(self, connection, metadata): + typ = self._type_round_trip( + connection, metadata, sql_types.String(52) + )[0] + assert isinstance(typ, sql_types.String) + eq_(typ.length, 52) + + @testing.requires.table_reflection + def test_nullable_reflection(self, connection, metadata): + t = Table( + "t", + metadata, + Column("a", Integer, nullable=True), + Column("b", Integer, nullable=False), + ) + t.create(connection) + eq_( + dict( + (col["name"], col["nullable"]) + for col in inspect(connection).get_columns("t") + ), + {"a": True, "b": False}, + ) + + @testing.combinations( + ( + None, + "CASCADE", + None, + testing.requires.foreign_key_constraint_option_reflection_ondelete, + ), + ( + None, + None, + "SET NULL", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + None, + "NO ACTION", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + "NO ACTION", + None, + testing.requires.fk_constraint_option_reflection_ondelete_noaction, + ), + ( + None, + None, + "RESTRICT", + testing.requires.fk_constraint_option_reflection_onupdate_restrict, + ), + ( + None, + "RESTRICT", + None, + testing.requires.fk_constraint_option_reflection_ondelete_restrict, + ), + argnames="expected,ondelete,onupdate", + ) + def test_get_foreign_key_options( + self, connection, metadata, expected, ondelete, onupdate + ): + options = {} + if ondelete: + options["ondelete"] = ondelete + if onupdate: + options["onupdate"] = onupdate + + if expected is None: + expected = options + + Table( + "x", + metadata, + Column("id", Integer, primary_key=True), + test_needs_fk=True, + ) + + Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")), + Column("test", String(10)), + test_needs_fk=True, + ) + + Table( + "user", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + sa.ForeignKeyConstraint( + ["tid"], ["table.id"], name="myfk", **options + ), + test_needs_fk=True, + ) + + metadata.create_all(connection) + + insp = inspect(connection) + + # test 'options' is always present for a backend + # that can reflect these, since alembic looks for this + opts = insp.get_foreign_keys("table")[0]["options"] + + eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) + + opts = insp.get_foreign_keys("user")[0]["options"] + eq_(opts, expected) + # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) + + +class NormalizedNameTest(fixtures.TablesTest): + __requires__ = ("denormalized_names",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + quoted_name("t1", quote=True), + metadata, + Column("id", Integer, primary_key=True), + ) + Table( + quoted_name("t2", quote=True), + metadata, + Column("id", Integer, primary_key=True), + Column("t1id", ForeignKey("t1.id")), + ) + + def test_reflect_lowercase_forced_tables(self): + + m2 = MetaData() + t2_ref = Table( + quoted_name("t2", quote=True), m2, autoload_with=config.db + ) + t1_ref = m2.tables["t1"] + assert t2_ref.c.t1id.references(t1_ref.c.id) + + m3 = MetaData() + m3.reflect( + config.db, only=lambda name, m: name.lower() in ("t1", "t2") + ) + assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id) + + def test_get_table_names(self): + tablenames = [ + t + for t in inspect(config.db).get_table_names() + if t.lower() in ("t1", "t2") + ] + + eq_(tablenames[0].upper(), tablenames[0].lower()) + eq_(tablenames[1].upper(), tablenames[1].lower()) + + +class ComputedReflectionTest(fixtures.ComputedReflectionFixtureTest): + def test_computed_col_default_not_set(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_default_table") + col_data = {c["name"]: c for c in cols} + is_true("42" in col_data["with_default"]["default"]) + is_(col_data["normal"]["default"], None) + is_(col_data["computed_col"]["default"], None) + + def test_get_column_returns_computed(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_default_table") + data = {c["name"]: c for c in cols} + for key in ("id", "normal", "with_default"): + is_true("computed" not in data[key]) + compData = data["computed_col"] + is_true("computed" in compData) + is_true("sqltext" in compData["computed"]) + eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42") + eq_( + "persisted" in compData["computed"], + testing.requires.computed_columns_reflect_persisted.enabled, + ) + if testing.requires.computed_columns_reflect_persisted.enabled: + eq_( + compData["computed"]["persisted"], + testing.requires.computed_columns_default_persisted.enabled, + ) + + def check_column(self, data, column, sqltext, persisted): + is_true("computed" in data[column]) + compData = data[column]["computed"] + eq_(self.normalize(compData["sqltext"]), sqltext) + if testing.requires.computed_columns_reflect_persisted.enabled: + is_true("persisted" in compData) + is_(compData["persisted"], persisted) + + def test_get_column_returns_persisted(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_column_table") + data = {c["name"]: c for c in cols} + + self.check_column( + data, + "computed_no_flag", + "normal+42", + testing.requires.computed_columns_default_persisted.enabled, + ) + if testing.requires.computed_columns_virtual.enabled: + self.check_column( + data, + "computed_virtual", + "normal+2", + False, + ) + if testing.requires.computed_columns_stored.enabled: + self.check_column( + data, + "computed_stored", + "normal-42", + True, + ) + + @testing.requires.schemas + def test_get_column_returns_persisted_with_schema(self): + insp = inspect(config.db) + + cols = insp.get_columns( + "computed_column_table", schema=config.test_schema + ) + data = {c["name"]: c for c in cols} + + self.check_column( + data, + "computed_no_flag", + "normal/42", + testing.requires.computed_columns_default_persisted.enabled, + ) + if testing.requires.computed_columns_virtual.enabled: + self.check_column( + data, + "computed_virtual", + "normal/2", + False, + ) + if testing.requires.computed_columns_stored.enabled: + self.check_column( + data, + "computed_stored", + "normal*42", + True, + ) + + +class IdentityReflectionTest(fixtures.TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + __requires__ = ("identity_columns", "table_reflection") + + @classmethod + def define_tables(cls, metadata): + Table( + "t1", + metadata, + Column("normal", Integer), + Column("id1", Integer, Identity()), + ) + Table( + "t2", + metadata, + Column( + "id2", + Integer, + Identity( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + ), + ) + if testing.requires.schemas.enabled: + Table( + "t1", + metadata, + Column("normal", Integer), + Column("id1", Integer, Identity(always=True, start=20)), + schema=config.test_schema, + ) + + def check(self, value, exp, approx): + if testing.requires.identity_columns_standard.enabled: + common_keys = ( + "always", + "start", + "increment", + "minvalue", + "maxvalue", + "cycle", + "cache", + ) + for k in list(value): + if k not in common_keys: + value.pop(k) + if approx: + eq_(len(value), len(exp)) + for k in value: + if k == "minvalue": + is_true(value[k] <= exp[k]) + elif k in {"maxvalue", "cache"}: + is_true(value[k] >= exp[k]) + else: + eq_(value[k], exp[k], k) + else: + eq_(value, exp) + else: + eq_(value["start"], exp["start"]) + eq_(value["increment"], exp["increment"]) + + def test_reflect_identity(self): + insp = inspect(config.db) + + cols = insp.get_columns("t1") + insp.get_columns("t2") + for col in cols: + if col["name"] == "normal": + is_false("identity" in col) + elif col["name"] == "id1": + is_true(col["autoincrement"] in (True, "auto")) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=False, + start=1, + increment=1, + minvalue=1, + maxvalue=2147483647, + cycle=False, + cache=1, + ), + approx=True, + ) + elif col["name"] == "id2": + is_true(col["autoincrement"] in (True, "auto")) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + approx=False, + ) + + @testing.requires.schemas + def test_reflect_identity_schema(self): + insp = inspect(config.db) + + cols = insp.get_columns("t1", schema=config.test_schema) + for col in cols: + if col["name"] == "normal": + is_false("identity" in col) + elif col["name"] == "id1": + is_true(col["autoincrement"] in (True, "auto")) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=True, + start=20, + increment=1, + minvalue=1, + maxvalue=2147483647, + cycle=False, + cache=1, + ), + approx=True, + ) + + +class CompositeKeyReflectionTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + tb1 = Table( + "tb1", + metadata, + Column("id", Integer), + Column("attr", Integer), + Column("name", sql_types.VARCHAR(20)), + sa.PrimaryKeyConstraint("name", "id", "attr", name="pk_tb1"), + schema=None, + test_needs_fk=True, + ) + Table( + "tb2", + metadata, + Column("id", Integer, primary_key=True), + Column("pid", Integer), + Column("pattr", Integer), + Column("pname", sql_types.VARCHAR(20)), + sa.ForeignKeyConstraint( + ["pname", "pid", "pattr"], + [tb1.c.name, tb1.c.id, tb1.c.attr], + name="fk_tb1_name_id_attr", + ), + schema=None, + test_needs_fk=True, + ) + + @testing.requires.primary_key_constraint_reflection + def test_pk_column_order(self): + # test for issue #5661 + insp = inspect(self.bind) + primary_key = insp.get_pk_constraint(self.tables.tb1.name) + eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"]) + + @testing.requires.foreign_key_constraint_reflection + def test_fk_column_order(self): + # test for issue #5661 + insp = inspect(self.bind) + foreign_keys = insp.get_foreign_keys(self.tables.tb2.name) + eq_(len(foreign_keys), 1) + fkey1 = foreign_keys[0] + eq_(fkey1.get("referred_columns"), ["name", "id", "attr"]) + eq_(fkey1.get("constrained_columns"), ["pname", "pid", "pattr"]) + + +__all__ = ( + "ComponentReflectionTest", + "ComponentReflectionTestExtra", + "TableNoColumnsTest", + "QuotedNameArgumentTest", + "HasTableTest", + "HasIndexTest", + "NormalizedNameTest", + "ComputedReflectionTest", + "IdentityReflectionTest", + "CompositeKeyReflectionTest", +) diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py new file mode 100644 index 0000000..c41a550 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -0,0 +1,426 @@ +import datetime + +from .. import engines +from .. import fixtures +from ..assertions import eq_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import DateTime +from ... import func +from ... import Integer +from ... import select +from ... import sql +from ... import String +from ... import testing +from ... import text +from ... import util + + +class RowFetchTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + Table( + "has_dates", + metadata, + Column("id", Integer, primary_key=True), + Column("today", DateTime), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.plain_pk.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + connection.execute( + cls.tables.has_dates.insert(), + [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}], + ) + + def test_via_attr(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row.id, 1) + eq_(row.data, "d1") + + def test_via_string(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row._mapping["id"], 1) + eq_(row._mapping["data"], "d1") + + def test_via_int(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row[0], 1) + eq_(row[1], "d1") + + def test_via_col_object(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row._mapping[self.tables.plain_pk.c.id], 1) + eq_(row._mapping[self.tables.plain_pk.c.data], "d1") + + @requirements.duplicate_names_in_cursor_description + def test_row_with_dupe_names(self, connection): + result = connection.execute( + select( + self.tables.plain_pk.c.data, + self.tables.plain_pk.c.data.label("data"), + ).order_by(self.tables.plain_pk.c.id) + ) + row = result.first() + eq_(result.keys(), ["data", "data"]) + eq_(row, ("d1", "d1")) + + def test_row_w_scalar_select(self, connection): + """test that a scalar select as a column is returned as such + and that type conversion works OK. + + (this is half a SQLAlchemy Core test and half to catch database + backends that may have unusual behavior with scalar selects.) + + """ + datetable = self.tables.has_dates + s = select(datetable.alias("x").c.today).scalar_subquery() + s2 = select(datetable.c.id, s.label("somelabel")) + row = connection.execute(s2).first() + + eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0)) + + +class PercentSchemaNamesTest(fixtures.TablesTest): + """tests using percent signs, spaces in table and column names. + + This didn't work for PostgreSQL / MySQL drivers for a long time + but is now supported. + + """ + + __requires__ = ("percent_schema_names",) + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + cls.tables.percent_table = Table( + "percent%table", + metadata, + Column("percent%", Integer), + Column("spaces % more spaces", Integer), + ) + cls.tables.lightweight_percent_table = sql.table( + "percent%table", + sql.column("percent%"), + sql.column("spaces % more spaces"), + ) + + def test_single_roundtrip(self, connection): + percent_table = self.tables.percent_table + for params in [ + {"percent%": 5, "spaces % more spaces": 12}, + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ]: + connection.execute(percent_table.insert(), params) + self._assert_table(connection) + + def test_executemany_roundtrip(self, connection): + percent_table = self.tables.percent_table + connection.execute( + percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12} + ) + connection.execute( + percent_table.insert(), + [ + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ], + ) + self._assert_table(connection) + + def _assert_table(self, conn): + percent_table = self.tables.percent_table + lightweight_percent_table = self.tables.lightweight_percent_table + + for table in ( + percent_table, + percent_table.alias(), + lightweight_percent_table, + lightweight_percent_table.alias(), + ): + eq_( + list( + conn.execute(table.select().order_by(table.c["percent%"])) + ), + [(5, 12), (7, 11), (9, 10), (11, 9)], + ) + + eq_( + list( + conn.execute( + table.select() + .where(table.c["spaces % more spaces"].in_([9, 10])) + .order_by(table.c["percent%"]) + ) + ), + [(9, 10), (11, 9)], + ) + + row = conn.execute( + table.select().order_by(table.c["percent%"]) + ).first() + eq_(row._mapping["percent%"], 5) + eq_(row._mapping["spaces % more spaces"], 12) + + eq_(row._mapping[table.c["percent%"]], 5) + eq_(row._mapping[table.c["spaces % more spaces"]], 12) + + conn.execute( + percent_table.update().values( + {percent_table.c["spaces % more spaces"]: 15} + ) + ) + + eq_( + list( + conn.execute( + percent_table.select().order_by( + percent_table.c["percent%"] + ) + ) + ), + [(5, 15), (7, 15), (9, 15), (11, 15)], + ) + + +class ServerSideCursorsTest( + fixtures.TestBase, testing.AssertsExecutionResults +): + + __requires__ = ("server_side_cursors",) + + __backend__ = True + + def _is_server_side(self, cursor): + # TODO: this is a huge issue as it prevents these tests from being + # usable by third party dialects. + if self.engine.dialect.driver == "psycopg2": + return bool(cursor.name) + elif self.engine.dialect.driver == "pymysql": + sscursor = __import__("pymysql.cursors").cursors.SSCursor + return isinstance(cursor, sscursor) + elif self.engine.dialect.driver in ("aiomysql", "asyncmy"): + return cursor.server_side + elif self.engine.dialect.driver == "mysqldb": + sscursor = __import__("MySQLdb.cursors").cursors.SSCursor + return isinstance(cursor, sscursor) + elif self.engine.dialect.driver == "mariadbconnector": + return not cursor.buffered + elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"): + return cursor.server_side + elif self.engine.dialect.driver == "pg8000": + return getattr(cursor, "server_side", False) + else: + return False + + def _fixture(self, server_side_cursors): + if server_side_cursors: + with testing.expect_deprecated( + "The create_engine.server_side_cursors parameter is " + "deprecated and will be removed in a future release. " + "Please use the Connection.execution_options.stream_results " + "parameter." + ): + self.engine = engines.testing_engine( + options={"server_side_cursors": server_side_cursors} + ) + else: + self.engine = engines.testing_engine( + options={"server_side_cursors": server_side_cursors} + ) + return self.engine + + @testing.combinations( + ("global_string", True, "select 1", True), + ("global_text", True, text("select 1"), True), + ("global_expr", True, select(1), True), + ("global_off_explicit", False, text("select 1"), False), + ( + "stmt_option", + False, + select(1).execution_options(stream_results=True), + True, + ), + ( + "stmt_option_disabled", + True, + select(1).execution_options(stream_results=False), + False, + ), + ("for_update_expr", True, select(1).with_for_update(), True), + # TODO: need a real requirement for this, or dont use this test + ( + "for_update_string", + True, + "SELECT 1 FOR UPDATE", + True, + testing.skip_if("sqlite"), + ), + ("text_no_ss", False, text("select 42"), False), + ( + "text_ss_option", + False, + text("select 42").execution_options(stream_results=True), + True, + ), + id_="iaaa", + argnames="engine_ss_arg, statement, cursor_ss_status", + ) + def test_ss_cursor_status( + self, engine_ss_arg, statement, cursor_ss_status + ): + engine = self._fixture(engine_ss_arg) + with engine.begin() as conn: + if isinstance(statement, util.string_types): + result = conn.exec_driver_sql(statement) + else: + result = conn.execute(statement) + eq_(self._is_server_side(result.cursor), cursor_ss_status) + result.close() + + def test_conn_option(self): + engine = self._fixture(False) + + with engine.connect() as conn: + # should be enabled for this one + result = conn.execution_options( + stream_results=True + ).exec_driver_sql("select 1") + assert self._is_server_side(result.cursor) + + def test_stmt_enabled_conn_option_disabled(self): + engine = self._fixture(False) + + s = select(1).execution_options(stream_results=True) + + with engine.connect() as conn: + # not this one + result = conn.execution_options(stream_results=False).execute(s) + assert not self._is_server_side(result.cursor) + + def test_aliases_and_ss(self): + engine = self._fixture(False) + s1 = ( + select(sql.literal_column("1").label("x")) + .execution_options(stream_results=True) + .subquery() + ) + + # options don't propagate out when subquery is used as a FROM clause + with engine.begin() as conn: + result = conn.execute(s1.select()) + assert not self._is_server_side(result.cursor) + result.close() + + s2 = select(1).select_from(s1) + with engine.begin() as conn: + result = conn.execute(s2) + assert not self._is_server_side(result.cursor) + result.close() + + def test_roundtrip_fetchall(self, metadata): + md = self.metadata + + engine = self._fixture(True) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + with engine.begin() as connection: + test_table.create(connection, checkfirst=True) + connection.execute(test_table.insert(), dict(data="data1")) + connection.execute(test_table.insert(), dict(data="data2")) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2")], + ) + connection.execute( + test_table.update() + .where(test_table.c.id == 2) + .values(data=test_table.c.data + " updated") + ) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2 updated")], + ) + connection.execute(test_table.delete()) + eq_( + connection.scalar( + select(func.count("*")).select_from(test_table) + ), + 0, + ) + + def test_roundtrip_fetchmany(self, metadata): + md = self.metadata + + engine = self._fixture(True) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + with engine.begin() as connection: + test_table.create(connection, checkfirst=True) + connection.execute( + test_table.insert(), + [dict(data="data%d" % i) for i in range(1, 20)], + ) + + result = connection.execute( + test_table.select().order_by(test_table.c.id) + ) + + eq_( + result.fetchmany(5), + [(i, "data%d" % i) for i in range(1, 6)], + ) + eq_( + result.fetchmany(10), + [(i, "data%d" % i) for i in range(6, 16)], + ) + eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)]) diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py new file mode 100644 index 0000000..82e831f --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_rowcount.py @@ -0,0 +1,165 @@ +from sqlalchemy import bindparam +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import text +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures + + +class RowCountTest(fixtures.TablesTest): + """test rowcount functionality""" + + __requires__ = ("sane_rowcount",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "employees", + metadata, + Column( + "employee_id", + Integer, + autoincrement=False, + primary_key=True, + ), + Column("name", String(50)), + Column("department", String(1)), + ) + + @classmethod + def insert_data(cls, connection): + cls.data = data = [ + ("Angela", "A"), + ("Andrew", "A"), + ("Anand", "A"), + ("Bob", "B"), + ("Bobette", "B"), + ("Buffy", "B"), + ("Charlie", "C"), + ("Cynthia", "C"), + ("Chris", "C"), + ] + + employees_table = cls.tables.employees + connection.execute( + employees_table.insert(), + [ + {"employee_id": i, "name": n, "department": d} + for i, (n, d) in enumerate(data) + ], + ) + + def test_basic(self, connection): + employees_table = self.tables.employees + s = select( + employees_table.c.name, employees_table.c.department + ).order_by(employees_table.c.employee_id) + rows = connection.execute(s).fetchall() + + eq_(rows, self.data) + + def test_update_rowcount1(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 3 rows changed + department = employees_table.c.department + r = connection.execute( + employees_table.update().where(department == "C"), + {"department": "Z"}, + ) + assert r.rowcount == 3 + + def test_update_rowcount2(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 0 rows changed + department = employees_table.c.department + + r = connection.execute( + employees_table.update().where(department == "C"), + {"department": "C"}, + ) + eq_(r.rowcount, 3) + + @testing.requires.sane_rowcount_w_returning + def test_update_rowcount_return_defaults(self, connection): + employees_table = self.tables.employees + + department = employees_table.c.department + stmt = ( + employees_table.update() + .where(department == "C") + .values(name=employees_table.c.department + "Z") + .return_defaults() + ) + + r = connection.execute(stmt) + eq_(r.rowcount, 3) + + def test_raw_sql_rowcount(self, connection): + # test issue #3622, make sure eager rowcount is called for text + result = connection.exec_driver_sql( + "update employees set department='Z' where department='C'" + ) + eq_(result.rowcount, 3) + + def test_text_rowcount(self, connection): + # test issue #3622, make sure eager rowcount is called for text + result = connection.execute( + text("update employees set department='Z' " "where department='C'") + ) + eq_(result.rowcount, 3) + + def test_delete_rowcount(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 3 rows deleted + department = employees_table.c.department + r = connection.execute( + employees_table.delete().where(department == "C") + ) + eq_(r.rowcount, 3) + + @testing.requires.sane_multi_rowcount + def test_multi_update_rowcount(self, connection): + employees_table = self.tables.employees + stmt = ( + employees_table.update() + .where(employees_table.c.name == bindparam("emp_name")) + .values(department="C") + ) + + r = connection.execute( + stmt, + [ + {"emp_name": "Bob"}, + {"emp_name": "Cynthia"}, + {"emp_name": "nonexistent"}, + ], + ) + + eq_(r.rowcount, 2) + + @testing.requires.sane_multi_rowcount + def test_multi_delete_rowcount(self, connection): + employees_table = self.tables.employees + + stmt = employees_table.delete().where( + employees_table.c.name == bindparam("emp_name") + ) + + r = connection.execute( + stmt, + [ + {"emp_name": "Bob"}, + {"emp_name": "Cynthia"}, + {"emp_name": "nonexistent"}, + ], + ) + + eq_(r.rowcount, 2) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py new file mode 100644 index 0000000..cb78fff --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -0,0 +1,1783 @@ +import itertools + +from .. import AssertsCompiledSQL +from .. import AssertsExecutionResults +from .. import config +from .. import fixtures +from ..assertions import assert_raises +from ..assertions import eq_ +from ..assertions import in_ +from ..assertsql import CursorSQL +from ..schema import Column +from ..schema import Table +from ... import bindparam +from ... import case +from ... import column +from ... import Computed +from ... import exists +from ... import false +from ... import ForeignKey +from ... import func +from ... import Identity +from ... import Integer +from ... import literal +from ... import literal_column +from ... import null +from ... import select +from ... import String +from ... import table +from ... import testing +from ... import text +from ... import true +from ... import tuple_ +from ... import TupleType +from ... import union +from ... import util +from ... import values +from ...exc import DatabaseError +from ...exc import ProgrammingError +from ...util import collections_abc + + +class CollateTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(100)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "collate data1"}, + {"id": 2, "data": "collate data2"}, + ], + ) + + def _assert_result(self, select, result): + with config.db.connect() as conn: + eq_(conn.execute(select).fetchall(), result) + + @testing.requires.order_by_collation + def test_collate_order_by(self): + collation = testing.requires.get_order_by_collation(testing.config) + + self._assert_result( + select(self.tables.some_table).order_by( + self.tables.some_table.c.data.collate(collation).asc() + ), + [(1, "collate data1"), (2, "collate data2")], + ) + + +class OrderByLabelTest(fixtures.TablesTest): + """Test the dialect sends appropriate ORDER BY expressions when + labels are used. + + This essentially exercises the "supports_simple_order_by_label" + setting. + + """ + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("q", String(50)), + Column("p", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"}, + {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"}, + {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"}, + ], + ) + + def _assert_result(self, select, result): + with config.db.connect() as conn: + eq_(conn.execute(select).fetchall(), result) + + def test_plain(self): + table = self.tables.some_table + lx = table.c.x.label("lx") + self._assert_result(select(lx).order_by(lx), [(1,), (2,), (3,)]) + + def test_composed_int(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + self._assert_result(select(lx).order_by(lx), [(3,), (5,), (7,)]) + + def test_composed_multiple(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + ly = (func.lower(table.c.q) + table.c.p).label("ly") + self._assert_result( + select(lx, ly).order_by(lx, ly.desc()), + [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))], + ) + + def test_plain_desc(self): + table = self.tables.some_table + lx = table.c.x.label("lx") + self._assert_result(select(lx).order_by(lx.desc()), [(3,), (2,), (1,)]) + + def test_composed_int_desc(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + self._assert_result(select(lx).order_by(lx.desc()), [(7,), (5,), (3,)]) + + @testing.requires.group_by_complex_expression + def test_group_by_composed(self): + table = self.tables.some_table + expr = (table.c.x + table.c.y).label("lx") + stmt = ( + select(func.count(table.c.id), expr).group_by(expr).order_by(expr) + ) + self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)]) + + +class ValuesExpressionTest(fixtures.TestBase): + __requires__ = ("table_value_constructor",) + + __backend__ = True + + def test_tuples(self, connection): + value_expr = values( + column("id", Integer), column("name", String), name="my_values" + ).data([(1, "name1"), (2, "name2"), (3, "name3")]) + + eq_( + connection.execute(select(value_expr)).all(), + [(1, "name1"), (2, "name2"), (3, "name3")], + ) + + +class FetchLimitOffsetTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + {"id": 5, "x": 4, "y": 6}, + ], + ) + + def _assert_result( + self, connection, select, result, params=(), set_=False + ): + if set_: + query_res = connection.execute(select, params).fetchall() + eq_(len(query_res), len(result)) + eq_(set(query_res), set(result)) + + else: + eq_(connection.execute(select, params).fetchall(), result) + + def _assert_result_str(self, select, result, params=()): + conn = config.db.connect(close_with_result=True) + eq_(conn.exec_driver_sql(select, params).fetchall(), result) + + def test_simple_limit(self, connection): + table = self.tables.some_table + stmt = select(table).order_by(table.c.id) + self._assert_result( + connection, + stmt.limit(2), + [(1, 1, 2), (2, 2, 3)], + ) + self._assert_result( + connection, + stmt.limit(3), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + def test_limit_render_multiple_times(self, connection): + table = self.tables.some_table + stmt = select(table.c.id).limit(1).scalar_subquery() + + u = union(select(stmt), select(stmt)).subquery().select() + + self._assert_result( + connection, + u, + [ + (1,), + ], + ) + + @testing.requires.fetch_first + def test_simple_fetch(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(2), + [(1, 1, 2), (2, 2, 3)], + ) + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(3), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.offset + def test_simple_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(2), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(3), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.combinations( + ([(2, 0), (2, 1), (3, 2)]), + ([(2, 1), (2, 0), (3, 2)]), + ([(3, 1), (2, 1), (3, 1)]), + argnames="cases", + ) + @testing.requires.offset + def test_simple_limit_offset(self, connection, cases): + table = self.tables.some_table + connection = connection.execution_options(compiled_cache={}) + + assert_data = [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)] + + for limit, offset in cases: + expected = assert_data[offset : offset + limit] + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(limit).offset(offset), + expected, + ) + + @testing.requires.fetch_first + def test_simple_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(2).offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(3).offset(2), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.fetch_no_order_by + def test_fetch_offset_no_order(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).fetch(10), + [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.offset + def test_simple_offset_zero(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(0), + [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(1), + [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.offset + def test_limit_offset_nobinds(self): + """test that 'literal binds' mode works - no bound params.""" + + table = self.tables.some_table + stmt = select(table).order_by(table.c.id).limit(2).offset(1) + sql = stmt.compile( + dialect=config.db.dialect, compile_kwargs={"literal_binds": True} + ) + sql = str(sql) + + self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)]) + + @testing.requires.fetch_first + def test_fetch_offset_nobinds(self): + """test that 'literal binds' mode works - no bound params.""" + + table = self.tables.some_table + stmt = select(table).order_by(table.c.id).fetch(2).offset(1) + sql = stmt.compile( + dialect=config.db.dialect, compile_kwargs={"literal_binds": True} + ) + sql = str(sql) + + self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)]) + + @testing.requires.bound_limit_offset + def test_bound_limit(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(bindparam("l")), + [(1, 1, 2), (2, 2, 3)], + params={"l": 2}, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(bindparam("l")), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + params={"l": 3}, + ) + + @testing.requires.bound_limit_offset + def test_bound_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"o": 2}, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"o": 1}, + ) + + @testing.requires.bound_limit_offset + def test_bound_limit_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(bindparam("l")) + .offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4)], + params={"l": 2, "o": 1}, + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(bindparam("l")) + .offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"l": 3, "o": 2}, + ) + + @testing.requires.fetch_first + def test_bound_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(bindparam("f")) + .offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4)], + params={"f": 2, "o": 1}, + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(bindparam("f")) + .offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"f": 3, "o": 2}, + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .offset(literal_column("1") + literal_column("2")), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("2")), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + @testing.requires.fetch_first + @testing.requires.fetch_expression + def test_expr_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(literal_column("1") + literal_column("1")) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + @testing.requires.sql_expression_limit_offset + def test_simple_limit_expr_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(2) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(3) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit_simple_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(2), + [(3, 3, 4), (4, 4, 5)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.fetch_ties + def test_simple_fetch_ties(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.x.desc()).fetch(1, with_ties=True), + [(4, 4, 5), (5, 4, 6)], + set_=True, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.x.desc()).fetch(3, with_ties=True), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_offset_with_options + def test_fetch_offset_ties(self, connection): + table = self.tables.some_table + fa = connection.execute( + select(table) + .order_by(table.c.x) + .fetch(2, with_ties=True) + .offset(2) + ).fetchall() + eq_(fa[0], (3, 3, 4)) + eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) + + @testing.requires.fetch_ties + @testing.requires.fetch_offset_with_options + def test_fetch_offset_ties_exact_number(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.x) + .fetch(2, with_ties=True) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.x) + .fetch(3, with_ties=True) + .offset(3), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.fetch_percent + def test_simple_fetch_percent(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(20, percent=True), + [(1, 1, 2)], + ) + + @testing.requires.fetch_percent + @testing.requires.fetch_offset_with_options + def test_fetch_offset_percent(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(40, percent=True) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_percent + def test_simple_fetch_percent_ties(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.x.desc()) + .fetch(20, percent=True, with_ties=True), + [(4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_percent + @testing.requires.fetch_offset_with_options + def test_fetch_offset_percent_ties(self, connection): + table = self.tables.some_table + fa = connection.execute( + select(table) + .order_by(table.c.x) + .fetch(40, percent=True, with_ties=True) + .offset(2) + ).fetchall() + eq_(fa[0], (3, 3, 4)) + eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) + + +class JoinTest(fixtures.TablesTest): + __backend__ = True + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + @classmethod + def define_tables(cls, metadata): + Table("a", metadata, Column("id", Integer, primary_key=True)) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("a_id", ForeignKey("a.id"), nullable=False), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.a.insert(), + [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}], + ) + + connection.execute( + cls.tables.b.insert(), + [ + {"id": 1, "a_id": 1}, + {"id": 2, "a_id": 1}, + {"id": 4, "a_id": 2}, + {"id": 5, "a_id": 3}, + ], + ) + + def test_inner_join_fk(self): + a, b = self.tables("a", "b") + + stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id) + + self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)]) + + def test_inner_join_true(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.join(b, true())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result( + stmt, + [ + (a, b, c) + for (a,), (b, c) in itertools.product( + [(1,), (2,), (3,), (4,), (5,)], + [(1, 1), (2, 1), (4, 2), (5, 3)], + ) + ], + ) + + def test_inner_join_false(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.join(b, false())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result(stmt, []) + + def test_outer_join_false(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.outerjoin(b, false())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result( + stmt, + [ + (1, None, None), + (2, None, None), + (3, None, None), + (4, None, None), + (5, None, None), + ], + ) + + def test_outer_join_fk(self): + a, b = self.tables("a", "b") + + stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id) + + self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)]) + + +class CompoundSelectTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ], + ) + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + def test_plain_union(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_select_from_plain_union(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2).alias().select() + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.order_by_col_from_union + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_wo_limit_offset + def test_order_by_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_distinct_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).distinct() + s2 = select(table).where(table.c.id == 3).distinct() + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_in_unions_from_alias(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + # this necessarily has double parens + u1 = union(s1, s2).alias() + self._assert_result( + u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_limit_offset_aliased_selectable_in_unions(self): + table = self.tables.some_table + s1 = ( + select(table) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select(table) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + +class PostCompileParamsTest( + AssertsExecutionResults, AssertsCompiledSQL, fixtures.TablesTest +): + __backend__ = True + + __requires__ = ("standard_cursor_sql",) + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "z": "z1"}, + {"id": 2, "x": 2, "y": 3, "z": "z2"}, + {"id": 3, "x": 3, "y": 4, "z": "z3"}, + {"id": 4, "x": 4, "y": 5, "z": "z4"}, + ], + ) + + def test_compile(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", literal_execute=True) + ) + + self.assert_compile( + stmt, + "SELECT some_table.id FROM some_table " + "WHERE some_table.x = __[POSTCOMPILE_q]", + {}, + ) + + def test_compile_literal_binds(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", 10, literal_execute=True) + ) + + self.assert_compile( + stmt, + "SELECT some_table.id FROM some_table WHERE some_table.x = 10", + {}, + literal_binds=True, + ) + + def test_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", literal_execute=True) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=10)) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE some_table.x = 10", + () if config.db.dialect.positional else {}, + ) + ) + + def test_execute_expanding_plus_literal_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x.in_(bindparam("q", expanding=True, literal_execute=True)) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[5, 6, 7])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE some_table.x IN (5, 6, 7)", + () if config.db.dialect.positional else {}, + ) + ) + + @testing.requires.tuple_in + def test_execute_tuple_expanding_plus_literal_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + tuple_(table.c.x, table.c.y).in_( + bindparam("q", expanding=True, literal_execute=True) + ) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[(5, 10), (12, 18)])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE (some_table.x, some_table.y) " + "IN (%s(5, 10), (12, 18))" + % ("VALUES " if config.db.dialect.tuple_in_values else ""), + () if config.db.dialect.positional else {}, + ) + ) + + @testing.requires.tuple_in + def test_execute_tuple_expanding_plus_literal_heterogeneous_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + tuple_(table.c.x, table.c.z).in_( + bindparam("q", expanding=True, literal_execute=True) + ) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[(5, "z1"), (12, "z3")])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE (some_table.x, some_table.z) " + "IN (%s(5, 'z1'), (12, 'z3'))" + % ("VALUES " if config.db.dialect.tuple_in_values else ""), + () if config.db.dialect.positional else {}, + ) + ) + + +class ExpandingBoundInTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "z": "z1"}, + {"id": 2, "x": 2, "y": 3, "z": "z2"}, + {"id": 3, "x": 3, "y": 4, "z": "z3"}, + {"id": 4, "x": 4, "y": 5, "z": "z4"}, + ], + ) + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + def test_multiple_empty_sets_bindparam(self): + # test that any anonymous aliasing used by the dialect + # is fine with duplicates + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .where(table.c.y.in_(bindparam("p"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": [], "p": []}) + + def test_multiple_empty_sets_direct(self): + # test that any anonymous aliasing used by the dialect + # is fine with duplicates + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([])) + .where(table.c.y.in_([])) + .order_by(table.c.id) + ) + self._assert_result(stmt, []) + + @testing.requires.tuple_in_w_empty + def test_empty_heterogeneous_tuples_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + @testing.requires.tuple_in_w_empty + def test_empty_heterogeneous_tuples_direct(self): + table = self.tables.some_table + + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)]) + go([], []) + + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_direct(self): + table = self.tables.some_table + + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)]) + go([], []) + + def test_bound_in_scalar_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]}) + + def test_bound_in_scalar_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([2, 3, 4])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + + def test_nonempty_in_plus_empty_notin(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([2, 3])) + .where(table.c.id.not_in([])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,)]) + + def test_empty_in_plus_notempty_notin(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([])) + .where(table.c.id.not_in([2, 3])) + .order_by(table.c.id) + ) + self._assert_result(stmt, []) + + def test_typed_str_in(self): + """test related to #7292. + + as a type is given to the bound param, there is no ambiguity + to the type of element. + + """ + + stmt = text( + "select id FROM some_table WHERE z IN :q ORDER BY id" + ).bindparams(bindparam("q", type_=String, expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": ["z2", "z3", "z4"]}, + ) + + def test_untyped_str_in(self): + """test related to #7292. + + for untyped expression, we look at the types of elements. + Test for Sequence to detect tuple in. but not strings or bytes! + as always.... + + """ + + stmt = text( + "select id FROM some_table WHERE z IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": ["z2", "z3", "z4"]}, + ) + + @testing.requires.tuple_in + def test_bound_in_two_tuple_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result( + stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]} + ) + + @testing.requires.tuple_in + def test_bound_in_two_tuple_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where( + tuple_(table.c.x, table.c.z).in_( + [(2, "z2"), (3, "z3"), (4, "z4")] + ) + ) + .order_by(table.c.id) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_text_bindparam(self): + # note this becomes ARRAY if we dont use expanding + # explicitly right now + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self): + class LikeATuple(collections_abc.Sequence): + def __init__(self, *data): + self._data = data + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams( + bindparam( + "q", type_=TupleType(Integer(), String()), expanding=True + ) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={ + "q": [ + LikeATuple(2, "z2"), + LikeATuple(3, "z3"), + LikeATuple(4, "z4"), + ] + }, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self): + # note this becomes ARRAY if we dont use expanding + # explicitly right now + + class LikeATuple(collections_abc.Sequence): + def __init__(self, *data): + self._data = data + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={ + "q": [ + LikeATuple(2, "z2"), + LikeATuple(3, "z3"), + LikeATuple(4, "z4"), + ] + }, + ) + + def test_empty_set_against_integer_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + def test_empty_set_against_integer_direct(self): + table = self.tables.some_table + stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + + def test_empty_set_against_integer_negation_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.not_in(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + + def test_empty_set_against_integer_negation_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + + def test_empty_set_against_string_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.z.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + def test_empty_set_against_string_direct(self): + table = self.tables.some_table + stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + + def test_empty_set_against_string_negation_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.z.not_in(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + + def test_empty_set_against_string_negation_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + + def test_null_in_empty_set_is_false_bindparam(self, connection): + stmt = select( + case( + ( + null().in_(bindparam("foo", value=())), + true(), + ), + else_=false(), + ) + ) + in_(connection.execute(stmt).fetchone()[0], (False, 0)) + + def test_null_in_empty_set_is_false_direct(self, connection): + stmt = select( + case( + ( + null().in_([]), + true(), + ), + else_=false(), + ) + ) + in_(connection.execute(stmt).fetchone()[0], (False, 0)) + + +class LikeFunctionsTest(fixtures.TablesTest): + __backend__ = True + + run_inserts = "once" + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "abcdefg"}, + {"id": 2, "data": "ab/cdefg"}, + {"id": 3, "data": "ab%cdefg"}, + {"id": 4, "data": "ab_cdefg"}, + {"id": 5, "data": "abcde/fg"}, + {"id": 6, "data": "abcde%fg"}, + {"id": 7, "data": "ab#cdefg"}, + {"id": 8, "data": "ab9cdefg"}, + {"id": 9, "data": "abcde#fg"}, + {"id": 10, "data": "abcd9fg"}, + {"id": 11, "data": None}, + ], + ) + + def _test(self, expr, expected): + some_table = self.tables.some_table + + with config.db.connect() as conn: + rows = { + value + for value, in conn.execute(select(some_table.c.id).where(expr)) + } + + eq_(rows, expected) + + def test_startswith_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c"), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + + def test_startswith_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c", autoescape=True), {3}) + + def test_startswith_sqlexpr(self): + col = self.tables.some_table.c.data + self._test( + col.startswith(literal_column("'ab%c'")), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + ) + + def test_startswith_escape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab##c", escape="#"), {7}) + + def test_startswith_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c", autoescape=True, escape="#"), {3}) + self._test(col.startswith("ab#c", autoescape=True, escape="#"), {7}) + + def test_endswith_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg"), {1, 2, 3, 4, 5, 6, 7, 8, 9}) + + def test_endswith_sqlexpr(self): + col = self.tables.some_table.c.data + self._test( + col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9} + ) + + def test_endswith_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg", autoescape=True), {6}) + + def test_endswith_escape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e##fg", escape="#"), {9}) + + def test_endswith_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg", autoescape=True, escape="#"), {6}) + self._test(col.endswith("e#fg", autoescape=True, escape="#"), {9}) + + def test_contains_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cde"), {1, 2, 3, 4, 5, 6, 7, 8, 9}) + + def test_contains_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cde", autoescape=True), {3}) + + def test_contains_escape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b##cde", escape="#"), {7}) + + def test_contains_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) + self._test(col.contains("b#cd", autoescape=True, escape="#"), {7}) + + @testing.requires.regexp_match + def test_not_regexp_match(self): + col = self.tables.some_table.c.data + self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10}) + + @testing.requires.regexp_replace + def test_regexp_replace(self): + col = self.tables.some_table.c.data + self._test( + col.regexp_replace("a.cde", "FOO").contains("FOO"), {1, 5, 6, 9} + ) + + @testing.requires.regexp_match + @testing.combinations( + ("a.cde", {1, 5, 6, 9}), + ("abc", {1, 5, 6, 9, 10}), + ("^abc", {1, 5, 6, 9, 10}), + ("9cde", {8}), + ("^a", set(range(1, 11))), + ("(b|c)", set(range(1, 11))), + ("^(b|c)", set()), + ) + def test_regexp_match(self, text, expected): + col = self.tables.some_table.c.data + self._test(col.regexp_match(text), expected) + + +class ComputedColumnTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("computed_columns",) + + @classmethod + def define_tables(cls, metadata): + Table( + "square", + metadata, + Column("id", Integer, primary_key=True), + Column("side", Integer), + Column("area", Integer, Computed("side * side")), + Column("perimeter", Integer, Computed("4 * side")), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.square.insert(), + [{"id": 1, "side": 10}, {"id": 10, "side": 42}], + ) + + def test_select_all(self): + with config.db.connect() as conn: + res = conn.execute( + select(text("*")) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(1, 10, 100, 40), (10, 42, 1764, 168)]) + + def test_select_columns(self): + with config.db.connect() as conn: + res = conn.execute( + select( + self.tables.square.c.area, self.tables.square.c.perimeter + ) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(100, 40), (1764, 168)]) + + +class IdentityColumnTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("identity_columns",) + run_inserts = "once" + run_deletes = "once" + + @classmethod + def define_tables(cls, metadata): + Table( + "tbl_a", + metadata, + Column( + "id", + Integer, + Identity( + always=True, start=42, nominvalue=True, nomaxvalue=True + ), + primary_key=True, + ), + Column("desc", String(100)), + ) + Table( + "tbl_b", + metadata, + Column( + "id", + Integer, + Identity(increment=-5, start=0, minvalue=-1000, maxvalue=0), + primary_key=True, + ), + Column("desc", String(100)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.tbl_a.insert(), + [{"desc": "a"}, {"desc": "b"}], + ) + connection.execute( + cls.tables.tbl_b.insert(), + [{"desc": "a"}, {"desc": "b"}], + ) + connection.execute( + cls.tables.tbl_b.insert(), + [{"id": 42, "desc": "c"}], + ) + + def test_select_all(self, connection): + res = connection.execute( + select(text("*")) + .select_from(self.tables.tbl_a) + .order_by(self.tables.tbl_a.c.id) + ).fetchall() + eq_(res, [(42, "a"), (43, "b")]) + + res = connection.execute( + select(text("*")) + .select_from(self.tables.tbl_b) + .order_by(self.tables.tbl_b.c.id) + ).fetchall() + eq_(res, [(-5, "b"), (0, "a"), (42, "c")]) + + def test_select_columns(self, connection): + + res = connection.execute( + select(self.tables.tbl_a.c.id).order_by(self.tables.tbl_a.c.id) + ).fetchall() + eq_(res, [(42,), (43,)]) + + @testing.requires.identity_columns_standard + def test_insert_always_error(self, connection): + def fn(): + connection.execute( + self.tables.tbl_a.insert(), + [{"id": 200, "desc": "a"}], + ) + + assert_raises((DatabaseError, ProgrammingError), fn) + + +class IdentityAutoincrementTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("autoincrement_without_sequence",) + + @classmethod + def define_tables(cls, metadata): + Table( + "tbl", + metadata, + Column( + "id", + Integer, + Identity(), + primary_key=True, + autoincrement=True, + ), + Column("desc", String(100)), + ) + + def test_autoincrement_with_identity(self, connection): + res = connection.execute(self.tables.tbl.insert(), {"desc": "row"}) + res = connection.execute(self.tables.tbl.select()).first() + eq_(res, (1, "row")) + + +class ExistsTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "stuff", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.stuff.insert(), + [ + {"id": 1, "data": "some data"}, + {"id": 2, "data": "some data"}, + {"id": 3, "data": "some data"}, + {"id": 4, "data": "some other data"}, + ], + ) + + def test_select_exists(self, connection): + stuff = self.tables.stuff + eq_( + connection.execute( + select(literal(1)).where( + exists().where(stuff.c.data == "some data") + ) + ).fetchall(), + [(1,)], + ) + + def test_select_exists_false(self, connection): + stuff = self.tables.stuff + eq_( + connection.execute( + select(literal(1)).where( + exists().where(stuff.c.data == "no data") + ) + ).fetchall(), + [], + ) + + +class DistinctOnTest(AssertsCompiledSQL, fixtures.TablesTest): + __backend__ = True + + @testing.fails_if(testing.requires.supports_distinct_on) + def test_distinct_on(self): + stm = select("*").distinct(column("q")).select_from(table("foo")) + with testing.expect_deprecated( + "DISTINCT ON is currently supported only by the PostgreSQL " + ): + self.assert_compile(stm, "SELECT DISTINCT * FROM foo") + + +class IsOrIsNotDistinctFromTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("supports_is_distinct_from",) + + @classmethod + def define_tables(cls, metadata): + Table( + "is_distinct_test", + metadata, + Column("id", Integer, primary_key=True), + Column("col_a", Integer, nullable=True), + Column("col_b", Integer, nullable=True), + ) + + @testing.combinations( + ("both_int_different", 0, 1, 1), + ("both_int_same", 1, 1, 0), + ("one_null_first", None, 1, 1), + ("one_null_second", 0, None, 1), + ("both_null", None, None, 0), + id_="iaaa", + argnames="col_a_value, col_b_value, expected_row_count_for_is", + ) + def test_is_or_is_not_distinct_from( + self, col_a_value, col_b_value, expected_row_count_for_is, connection + ): + tbl = self.tables.is_distinct_test + + connection.execute( + tbl.insert(), + [{"id": 1, "col_a": col_a_value, "col_b": col_b_value}], + ) + + result = connection.execute( + tbl.select().where(tbl.c.col_a.is_distinct_from(tbl.c.col_b)) + ).fetchall() + eq_( + len(result), + expected_row_count_for_is, + ) + + expected_row_count_for_is_not = ( + 1 if expected_row_count_for_is == 0 else 0 + ) + result = connection.execute( + tbl.select().where(tbl.c.col_a.is_not_distinct_from(tbl.c.col_b)) + ).fetchall() + eq_( + len(result), + expected_row_count_for_is_not, + ) diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py new file mode 100644 index 0000000..d6747d2 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -0,0 +1,282 @@ +from .. import config +from .. import fixtures +from ..assertions import eq_ +from ..assertions import is_true +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import inspect +from ... import Integer +from ... import MetaData +from ... import Sequence +from ... import String +from ... import testing + + +class SequenceTest(fixtures.TablesTest): + __requires__ = ("sequences",) + __backend__ = True + + run_create_tables = "each" + + @classmethod + def define_tables(cls, metadata): + Table( + "seq_pk", + metadata, + Column( + "id", + Integer, + Sequence("tab_id_seq"), + primary_key=True, + ), + Column("data", String(50)), + ) + + Table( + "seq_opt_pk", + metadata, + Column( + "id", + Integer, + Sequence("tab_id_seq", data_type=Integer, optional=True), + primary_key=True, + ), + Column("data", String(50)), + ) + + Table( + "seq_no_returning", + metadata, + Column( + "id", + Integer, + Sequence("noret_id_seq"), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + ) + + if testing.requires.schemas.enabled: + Table( + "seq_no_returning_sch", + metadata, + Column( + "id", + Integer, + Sequence("noret_sch_id_seq", schema=config.test_schema), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + schema=config.test_schema, + ) + + def test_insert_roundtrip(self, connection): + connection.execute(self.tables.seq_pk.insert(), dict(data="some data")) + self._assert_round_trip(self.tables.seq_pk, connection) + + def test_insert_lastrowid(self, connection): + r = connection.execute( + self.tables.seq_pk.insert(), dict(data="some data") + ) + eq_( + r.inserted_primary_key, (testing.db.dialect.default_sequence_base,) + ) + + def test_nextval_direct(self, connection): + r = connection.execute(self.tables.seq_pk.c.id.default) + eq_(r, testing.db.dialect.default_sequence_base) + + @requirements.sequences_optional + def test_optional_seq(self, connection): + r = connection.execute( + self.tables.seq_opt_pk.insert(), dict(data="some data") + ) + eq_(r.inserted_primary_key, (1,)) + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_(row, (testing.db.dialect.default_sequence_base, "some data")) + + def test_insert_roundtrip_no_implicit_returning(self, connection): + connection.execute( + self.tables.seq_no_returning.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.seq_no_returning, connection) + + @testing.combinations((True,), (False,), argnames="implicit_returning") + @testing.requires.schemas + def test_insert_roundtrip_translate(self, connection, implicit_returning): + + seq_no_returning = Table( + "seq_no_returning_sch", + MetaData(), + Column( + "id", + Integer, + Sequence("noret_sch_id_seq", schema="alt_schema"), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=implicit_returning, + schema="alt_schema", + ) + + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + connection.execute(seq_no_returning.insert(), dict(data="some data")) + self._assert_round_trip(seq_no_returning, connection) + + @testing.requires.schemas + def test_nextval_direct_schema_translate(self, connection): + seq = Sequence("noret_sch_id_seq", schema="alt_schema") + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + + r = connection.execute(seq) + eq_(r, testing.db.dialect.default_sequence_base) + + +class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase): + __requires__ = ("sequences",) + __backend__ = True + + def test_literal_binds_inline_compile(self, connection): + table = Table( + "x", + MetaData(), + Column("y", Integer, Sequence("y_seq")), + Column("q", Integer), + ) + + stmt = table.insert().values(q=5) + + seq_nextval = connection.dialect.statement_compiler( + statement=None, dialect=connection.dialect + ).visit_sequence(Sequence("y_seq")) + self.assert_compile( + stmt, + "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,), + literal_binds=True, + dialect=connection.dialect, + ) + + +class HasSequenceTest(fixtures.TablesTest): + run_deletes = None + + __requires__ = ("sequences",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Sequence("user_id_seq", metadata=metadata) + Sequence( + "other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True + ) + if testing.requires.schemas.enabled: + Sequence( + "user_id_seq", schema=config.test_schema, metadata=metadata + ) + Sequence( + "schema_seq", schema=config.test_schema, metadata=metadata + ) + Table( + "user_id_table", + metadata, + Column("id", Integer, primary_key=True), + ) + + def test_has_sequence(self, connection): + eq_( + inspect(connection).has_sequence("user_id_seq"), + True, + ) + + def test_has_sequence_other_object(self, connection): + eq_( + inspect(connection).has_sequence("user_id_table"), + False, + ) + + @testing.requires.schemas + def test_has_sequence_schema(self, connection): + eq_( + inspect(connection).has_sequence( + "user_id_seq", schema=config.test_schema + ), + True, + ) + + def test_has_sequence_neg(self, connection): + eq_( + inspect(connection).has_sequence("some_sequence"), + False, + ) + + @testing.requires.schemas + def test_has_sequence_schemas_neg(self, connection): + eq_( + inspect(connection).has_sequence( + "some_sequence", schema=config.test_schema + ), + False, + ) + + @testing.requires.schemas + def test_has_sequence_default_not_in_remote(self, connection): + eq_( + inspect(connection).has_sequence( + "other_sequence", schema=config.test_schema + ), + False, + ) + + @testing.requires.schemas + def test_has_sequence_remote_not_in_default(self, connection): + eq_( + inspect(connection).has_sequence("schema_seq"), + False, + ) + + def test_get_sequence_names(self, connection): + exp = {"other_seq", "user_id_seq"} + + res = set(inspect(connection).get_sequence_names()) + is_true(res.intersection(exp) == exp) + is_true("schema_seq" not in res) + + @testing.requires.schemas + def test_get_sequence_names_no_sequence_schema(self, connection): + eq_( + inspect(connection).get_sequence_names( + schema=config.test_schema_2 + ), + [], + ) + + @testing.requires.schemas + def test_get_sequence_names_sequences_schema(self, connection): + eq_( + sorted( + inspect(connection).get_sequence_names( + schema=config.test_schema + ) + ), + ["schema_seq", "user_id_seq"], + ) + + +class HasSequenceTestEmpty(fixtures.TestBase): + __requires__ = ("sequences",) + __backend__ = True + + def test_get_sequence_names_no_sequence(self, connection): + eq_( + inspect(connection).get_sequence_names(), + [], + ) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py new file mode 100644 index 0000000..b96350e --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -0,0 +1,1508 @@ +# coding: utf-8 + +import datetime +import decimal +import json +import re + +from .. import config +from .. import engines +from .. import fixtures +from .. import mock +from ..assertions import eq_ +from ..assertions import is_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import and_ +from ... import BigInteger +from ... import bindparam +from ... import Boolean +from ... import case +from ... import cast +from ... import Date +from ... import DateTime +from ... import Float +from ... import Integer +from ... import JSON +from ... import literal +from ... import MetaData +from ... import null +from ... import Numeric +from ... import select +from ... import String +from ... import testing +from ... import Text +from ... import Time +from ... import TIMESTAMP +from ... import TypeDecorator +from ... import Unicode +from ... import UnicodeText +from ... import util +from ...orm import declarative_base +from ...orm import Session +from ...sql.sqltypes import LargeBinary +from ...sql.sqltypes import PickleType +from ...util import compat +from ...util import u + + +class _LiteralRoundTripFixture(object): + supports_whereclause = True + + @testing.fixture + def literal_round_trip(self, metadata, connection): + """test literal rendering""" + + # for literal, we test the literal render in an INSERT + # into a typed column. we can then SELECT it back as its + # official type; ideally we'd be able to use CAST here + # but MySQL in particular can't CAST fully + + def run(type_, input_, output, filter_=None): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + + for value in input_: + ins = ( + t.insert() + .values(x=literal(value, type_)) + .compile( + dialect=testing.db.dialect, + compile_kwargs=dict(literal_binds=True), + ) + ) + connection.execute(ins) + + if self.supports_whereclause: + stmt = t.select().where(t.c.x == literal(value)) + else: + stmt = t.select() + + stmt = stmt.compile( + dialect=testing.db.dialect, + compile_kwargs=dict(literal_binds=True), + ) + for row in connection.execute(stmt): + value = row[0] + if filter_ is not None: + value = filter_(value) + assert value in output + + return run + + +class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase): + __requires__ = ("unicode_data",) + + data = u( + "Alors vous imaginez ma 🐍 surprise, au lever du jour, " + "quand une drôle de petite 🐍 voix m’a réveillé. Elle " + "disait: « S’il vous plaît… dessine-moi 🐍 un mouton! »" + ) + + @property + def supports_whereclause(self): + return config.requirements.expressions_against_unbounded_text.enabled + + @classmethod + def define_tables(cls, metadata): + Table( + "unicode_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("unicode_data", cls.datatype), + ) + + def test_round_trip(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": self.data} + ) + + row = connection.execute(select(unicode_table.c.unicode_data)).first() + + eq_(row, (self.data,)) + assert isinstance(row[0], util.text_type) + + def test_round_trip_executemany(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), + [{"id": i, "unicode_data": self.data} for i in range(1, 4)], + ) + + rows = connection.execute( + select(unicode_table.c.unicode_data) + ).fetchall() + eq_(rows, [(self.data,) for i in range(1, 4)]) + for row in rows: + assert isinstance(row[0], util.text_type) + + def _test_null_strings(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": None} + ) + row = connection.execute(select(unicode_table.c.unicode_data)).first() + eq_(row, (None,)) + + def _test_empty_strings(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": u("")} + ) + row = connection.execute(select(unicode_table.c.unicode_data)).first() + eq_(row, (u(""),)) + + def test_literal(self, literal_round_trip): + literal_round_trip(self.datatype, [self.data], [self.data]) + + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( + self.datatype, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] + ) + + +class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): + __requires__ = ("unicode_data",) + __backend__ = True + + datatype = Unicode(255) + + @requirements.empty_strings_varchar + def test_empty_strings_varchar(self, connection): + self._test_empty_strings(connection) + + def test_null_strings_varchar(self, connection): + self._test_null_strings(connection) + + +class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): + __requires__ = "unicode_data", "text_type" + __backend__ = True + + datatype = UnicodeText() + + @requirements.empty_strings_text + def test_empty_strings_text(self, connection): + self._test_empty_strings(connection) + + def test_null_strings_text(self, connection): + self._test_null_strings(connection) + + +class BinaryTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __requires__ = ("binary_literals",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "binary_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("binary_data", LargeBinary), + Column("pickle_data", PickleType), + ) + + def test_binary_roundtrip(self, connection): + binary_table = self.tables.binary_table + + connection.execute( + binary_table.insert(), {"id": 1, "binary_data": b"this is binary"} + ) + row = connection.execute(select(binary_table.c.binary_data)).first() + eq_(row, (b"this is binary",)) + + def test_pickle_roundtrip(self, connection): + binary_table = self.tables.binary_table + + connection.execute( + binary_table.insert(), + {"id": 1, "pickle_data": {"foo": [1, 2, 3], "bar": "bat"}}, + ) + row = connection.execute(select(binary_table.c.pickle_data)).first() + eq_(row, ({"foo": [1, 2, 3], "bar": "bat"},)) + + +class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __requires__ = ("text_type",) + __backend__ = True + + @property + def supports_whereclause(self): + return config.requirements.expressions_against_unbounded_text.enabled + + @classmethod + def define_tables(cls, metadata): + Table( + "text_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("text_data", Text), + ) + + def test_text_roundtrip(self, connection): + text_table = self.tables.text_table + + connection.execute( + text_table.insert(), {"id": 1, "text_data": "some text"} + ) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, ("some text",)) + + @testing.requires.empty_strings_text + def test_text_empty_strings(self, connection): + text_table = self.tables.text_table + + connection.execute(text_table.insert(), {"id": 1, "text_data": ""}) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, ("",)) + + def test_text_null_strings(self, connection): + text_table = self.tables.text_table + + connection.execute(text_table.insert(), {"id": 1, "text_data": None}) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, (None,)) + + def test_literal(self, literal_round_trip): + literal_round_trip(Text, ["some text"], ["some text"]) + + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( + Text, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] + ) + + def test_literal_quoting(self, literal_round_trip): + data = """some 'text' hey "hi there" that's text""" + literal_round_trip(Text, [data], [data]) + + def test_literal_backslashes(self, literal_round_trip): + data = r"backslash one \ backslash two \\ end" + literal_round_trip(Text, [data], [data]) + + def test_literal_percentsigns(self, literal_round_trip): + data = r"percent % signs %% percent" + literal_round_trip(Text, [data], [data]) + + +class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @requirements.unbounded_varchar + def test_nolength_string(self): + metadata = MetaData() + foo = Table("foo", metadata, Column("one", String)) + + foo.create(config.db) + foo.drop(config.db) + + def test_literal(self, literal_round_trip): + # note that in Python 3, this invokes the Unicode + # datatype for the literal part because all strings are unicode + literal_round_trip(String(40), ["some text"], ["some text"]) + + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( + String(40), [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] + ) + + def test_literal_quoting(self, literal_round_trip): + data = """some 'text' hey "hi there" that's text""" + literal_round_trip(String(40), [data], [data]) + + def test_literal_backslashes(self, literal_round_trip): + data = r"backslash one \ backslash two \\ end" + literal_round_trip(String(40), [data], [data]) + + +class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): + compare = None + + @classmethod + def define_tables(cls, metadata): + class Decorated(TypeDecorator): + impl = cls.datatype + cache_ok = True + + Table( + "date_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("date_data", cls.datatype), + Column("decorated_date_data", Decorated), + ) + + @testing.requires.datetime_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + def test_round_trip(self, connection): + date_table = self.tables.date_table + + connection.execute( + date_table.insert(), {"id": 1, "date_data": self.data} + ) + + row = connection.execute(select(date_table.c.date_data)).first() + + compare = self.compare or self.data + eq_(row, (compare,)) + assert isinstance(row[0], type(compare)) + + def test_round_trip_decorated(self, connection): + date_table = self.tables.date_table + + connection.execute( + date_table.insert(), {"id": 1, "decorated_date_data": self.data} + ) + + row = connection.execute( + select(date_table.c.decorated_date_data) + ).first() + + compare = self.compare or self.data + eq_(row, (compare,)) + assert isinstance(row[0], type(compare)) + + def test_null(self, connection): + date_table = self.tables.date_table + + connection.execute(date_table.insert(), {"id": 1, "date_data": None}) + + row = connection.execute(select(date_table.c.date_data)).first() + eq_(row, (None,)) + + @testing.requires.datetime_literals + def test_literal(self, literal_round_trip): + compare = self.compare or self.data + literal_round_trip(self.datatype, [self.data], [compare]) + + @testing.requires.standalone_null_binds_whereclause + def test_null_bound_comparison(self): + # this test is based on an Oracle issue observed in #4886. + # passing NULL for an expression that needs to be interpreted as + # a certain type, does the DBAPI have the info it needs to do this. + date_table = self.tables.date_table + with config.db.begin() as conn: + result = conn.execute( + date_table.insert(), {"id": 1, "date_data": self.data} + ) + id_ = result.inserted_primary_key[0] + stmt = select(date_table.c.id).where( + case( + ( + bindparam("foo", type_=self.datatype) != None, + bindparam("foo", type_=self.datatype), + ), + else_=date_table.c.date_data, + ) + == date_table.c.date_data + ) + + row = conn.execute(stmt, {"foo": None}).first() + eq_(row[0], id_) + + +class DateTimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(2012, 10, 15, 12, 57, 18) + + +class DateTimeTZTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_timezone",) + __backend__ = True + datatype = DateTime(timezone=True) + data = datetime.datetime( + 2012, 10, 15, 12, 57, 18, tzinfo=compat.timezone.utc + ) + + +class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_microseconds",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + + +class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("timestamp_microseconds",) + __backend__ = True + datatype = TIMESTAMP + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + + @testing.requires.timestamp_microseconds_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class TimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time",) + __backend__ = True + datatype = Time + data = datetime.time(12, 57, 18) + + +class TimeTZTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time_timezone",) + __backend__ = True + datatype = Time(timezone=True) + data = datetime.time(12, 57, 18, tzinfo=compat.timezone.utc) + + +class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time_microseconds",) + __backend__ = True + datatype = Time + data = datetime.time(12, 57, 18, 396) + + +class DateTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("date",) + __backend__ = True + datatype = Date + data = datetime.date(2012, 10, 15) + + +class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = "date", "date_coerces_from_datetime" + __backend__ = True + datatype = Date + data = datetime.datetime(2012, 10, 15, 12, 57, 18) + compare = datetime.date(2012, 10, 15) + + +class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_historic",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(1850, 11, 10, 11, 52, 35) + + +class DateHistoricTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("date_historic",) + __backend__ = True + datatype = Date + data = datetime.date(1727, 4, 1) + + +class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + def test_literal(self, literal_round_trip): + literal_round_trip(Integer, [5], [5]) + + def test_huge_int(self, integer_round_trip): + integer_round_trip(BigInteger, 1376537018368127) + + @testing.fixture + def integer_round_trip(self, metadata, connection): + def run(datatype, data): + int_table = Table( + "integer_table", + metadata, + Column( + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("integer_data", datatype), + ) + + metadata.create_all(config.db) + + connection.execute( + int_table.insert(), {"id": 1, "integer_data": data} + ) + + row = connection.execute(select(int_table.c.integer_data)).first() + + eq_(row, (data,)) + + if util.py3k: + assert isinstance(row[0], int) + else: + assert isinstance(row[0], (long, int)) # noqa + + return run + + +class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @testing.fixture + def string_as_int(self): + class StringAsInt(TypeDecorator): + impl = String(50) + cache_ok = True + + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + def column_expression(self, col): + return cast(col, Integer) + + def bind_expression(self, col): + return cast(col, String(50)) + + return StringAsInt() + + def test_special_type(self, metadata, connection, string_as_int): + + type_ = string_as_int + + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + + connection.execute(t.insert(), [{"x": x} for x in [1, 2, 3]]) + + result = {row[0] for row in connection.execute(t.select())} + eq_(result, {1, 2, 3}) + + result = { + row[0] for row in connection.execute(t.select().where(t.c.x == 2)) + } + eq_(result, {2}) + + +class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @testing.fixture + def do_numeric_test(self, metadata, connection): + @testing.emits_warning( + r".*does \*not\* support Decimal objects natively" + ) + def run(type_, input_, output, filter_=None, check_scale=False): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + connection.execute(t.insert(), [{"x": x} for x in input_]) + + result = {row[0] for row in connection.execute(t.select())} + output = set(output) + if filter_: + result = set(filter_(x) for x in result) + output = set(filter_(x) for x in output) + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) + + return run + + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_render_literal_numeric(self, literal_round_trip): + literal_round_trip( + Numeric(precision=8, scale=4), + [15.7563, decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_render_literal_numeric_asfloat(self, literal_round_trip): + literal_round_trip( + Numeric(precision=8, scale=4, asdecimal=False), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + ) + + def test_render_literal_float(self, literal_round_trip): + literal_round_trip( + Float(4), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None, + ) + + @testing.requires.precision_generic_float_type + def test_float_custom_scale(self, do_numeric_test): + do_numeric_test( + Float(None, decimal_return_scale=7, asdecimal=True), + [15.7563827, decimal.Decimal("15.7563827")], + [decimal.Decimal("15.7563827")], + check_scale=True, + ) + + def test_numeric_as_decimal(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4), + [15.7563, decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + def test_numeric_as_float(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4, asdecimal=False), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + ) + + @testing.requires.infinity_floats + def test_infinity_floats(self, do_numeric_test): + """test for #977, #7283""" + + do_numeric_test( + Float(None), + [float("inf")], + [float("inf")], + ) + + @testing.requires.fetch_null_from_numeric + def test_numeric_null_as_decimal(self, do_numeric_test): + do_numeric_test(Numeric(precision=8, scale=4), [None], [None]) + + @testing.requires.fetch_null_from_numeric + def test_numeric_null_as_float(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4, asdecimal=False), [None], [None] + ) + + @testing.requires.floats_to_four_decimals + def test_float_as_decimal(self, do_numeric_test): + do_numeric_test( + Float(precision=8, asdecimal=True), + [15.7563, decimal.Decimal("15.7563"), None], + [decimal.Decimal("15.7563"), None], + filter_=lambda n: n is not None and round(n, 4) or None, + ) + + def test_float_as_float(self, do_numeric_test): + do_numeric_test( + Float(precision=8), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None, + ) + + def test_float_coerce_round_trip(self, connection): + expr = 15.7563 + + val = connection.scalar(select(literal(expr))) + eq_(val, expr) + + # this does not work in MySQL, see #4036, however we choose not + # to render CAST unconditionally since this is kind of an edge case. + + @testing.requires.implicit_decimal_binds + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_decimal_coerce_round_trip(self, connection): + expr = decimal.Decimal("15.7563") + + val = connection.scalar(select(literal(expr))) + eq_(val, expr) + + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_decimal_coerce_round_trip_w_cast(self, connection): + expr = decimal.Decimal("15.7563") + + val = connection.scalar(select(cast(expr, Numeric(10, 4)))) + eq_(val, expr) + + @testing.requires.precision_numerics_general + def test_precision_decimal(self, do_numeric_test): + numbers = set( + [ + decimal.Decimal("54.234246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + ] + ) + + do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers) + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal(self, do_numeric_test): + """test exceedingly small decimals. + + Decimal reports values with E notation when the exponent + is greater than 6. + + """ + + numbers = set( + [ + decimal.Decimal("1E-2"), + decimal.Decimal("1E-3"), + decimal.Decimal("1E-4"), + decimal.Decimal("1E-5"), + decimal.Decimal("1E-6"), + decimal.Decimal("1E-7"), + decimal.Decimal("1E-8"), + decimal.Decimal("0.01000005940696"), + decimal.Decimal("0.00000005940696"), + decimal.Decimal("0.00000000000696"), + decimal.Decimal("0.70000000000696"), + decimal.Decimal("696E-12"), + ] + ) + do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers) + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal_large(self, do_numeric_test): + """test exceedingly large decimals.""" + + numbers = set( + [ + decimal.Decimal("4E+8"), + decimal.Decimal("5748E+15"), + decimal.Decimal("1.521E+15"), + decimal.Decimal("00000000000000.1E+12"), + ] + ) + do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers) + + @testing.requires.precision_numerics_many_significant_digits + def test_many_significant_digits(self, do_numeric_test): + numbers = set( + [ + decimal.Decimal("31943874831932418390.01"), + decimal.Decimal("319438950232418390.273596"), + decimal.Decimal("87673.594069654243"), + ] + ) + do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers) + + @testing.requires.precision_numerics_retains_significant_digits + def test_numeric_no_decimal(self, do_numeric_test): + numbers = set([decimal.Decimal("1.000")]) + do_numeric_test( + Numeric(precision=5, scale=3), numbers, numbers, check_scale=True + ) + + +class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "boolean_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("value", Boolean), + Column("unconstrained_value", Boolean(create_constraint=False)), + ) + + def test_render_literal_bool(self, literal_round_trip): + literal_round_trip(Boolean(), [True, False], [True, False]) + + def test_round_trip(self, connection): + boolean_table = self.tables.boolean_table + + connection.execute( + boolean_table.insert(), + {"id": 1, "value": True, "unconstrained_value": False}, + ) + + row = connection.execute( + select(boolean_table.c.value, boolean_table.c.unconstrained_value) + ).first() + + eq_(row, (True, False)) + assert isinstance(row[0], bool) + + @testing.requires.nullable_booleans + def test_null(self, connection): + boolean_table = self.tables.boolean_table + + connection.execute( + boolean_table.insert(), + {"id": 1, "value": None, "unconstrained_value": None}, + ) + + row = connection.execute( + select(boolean_table.c.value, boolean_table.c.unconstrained_value) + ).first() + + eq_(row, (None, None)) + + def test_whereclause(self): + # testing "WHERE <column>" renders a compatible expression + boolean_table = self.tables.boolean_table + + with config.db.begin() as conn: + conn.execute( + boolean_table.insert(), + [ + {"id": 1, "value": True, "unconstrained_value": True}, + {"id": 2, "value": False, "unconstrained_value": False}, + ], + ) + + eq_( + conn.scalar( + select(boolean_table.c.id).where(boolean_table.c.value) + ), + 1, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where( + boolean_table.c.unconstrained_value + ) + ), + 1, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where(~boolean_table.c.value) + ), + 2, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where( + ~boolean_table.c.unconstrained_value + ) + ), + 2, + ) + + +class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __requires__ = ("json_type",) + __backend__ = True + + datatype = JSON + + @classmethod + def define_tables(cls, metadata): + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.datatype, nullable=False), + Column("nulldata", cls.datatype(none_as_null=True)), + ) + + def test_round_trip_data1(self, connection): + self._test_round_trip({"key1": "value1", "key2": "value2"}, connection) + + def _test_round_trip(self, data_element, connection): + data_table = self.tables.data_table + + connection.execute( + data_table.insert(), + {"id": 1, "name": "row1", "data": data_element}, + ) + + row = connection.execute(select(data_table.c.data)).first() + + eq_(row, (data_element,)) + + def _index_fixtures(include_comparison): + + if include_comparison: + # basically SQL Server and MariaDB can kind of do json + # comparison, MySQL, PG and SQLite can't. not worth it. + json_elements = [] + else: + json_elements = [ + ("json", {"foo": "bar"}), + ("json", ["one", "two", "three"]), + (None, {"foo": "bar"}), + (None, ["one", "two", "three"]), + ] + + elements = [ + ("boolean", True), + ("boolean", False), + ("boolean", None), + ("string", "some string"), + ("string", None), + ("string", util.u("réve illé")), + ( + "string", + util.u("réve🐍 illé"), + testing.requires.json_index_supplementary_unicode_element, + ), + ("integer", 15), + ("integer", 1), + ("integer", 0), + ("integer", None), + ("float", 28.5), + ("float", None), + ( + "float", + 1234567.89, + ), + ("numeric", 1234567.89), + # this one "works" because the float value you see here is + # lost immediately to floating point stuff + ("numeric", 99998969694839.983485848, requirements.python3), + ("numeric", 99939.983485848, requirements.python3), + ("_decimal", decimal.Decimal("1234567.89")), + ( + "_decimal", + decimal.Decimal("99998969694839.983485848"), + # fails on SQLite and MySQL (non-mariadb) + requirements.cast_precision_numerics_many_significant_digits, + ), + ( + "_decimal", + decimal.Decimal("99939.983485848"), + ), + ] + json_elements + + def decorate(fn): + fn = testing.combinations(id_="sa", *elements)(fn) + + return fn + + return decorate + + def _json_value_insert(self, connection, datatype, value, data_element): + data_table = self.tables.data_table + if datatype == "_decimal": + + # Python's builtin json serializer basically doesn't support + # Decimal objects without implicit float conversion period. + # users can otherwise use simplejson which supports + # precision decimals + + # https://bugs.python.org/issue16535 + + # inserting as strings to avoid a new fixture around the + # dialect which would have idiosyncrasies for different + # backends. + + class DecimalEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, decimal.Decimal): + return str(o) + return super(DecimalEncoder, self).default(o) + + json_data = json.dumps(data_element, cls=DecimalEncoder) + + # take the quotes out. yup, there is *literally* no other + # way to get Python's json.dumps() to put all the digits in + # the string + json_data = re.sub(r'"(%s)"' % str(value), str(value), json_data) + + datatype = "numeric" + + connection.execute( + data_table.insert().values( + name="row1", + # to pass the string directly to every backend, including + # PostgreSQL which needs the value to be CAST as JSON + # both in the SQL as well as at the prepared statement + # level for asyncpg, while at the same time MySQL + # doesn't even support CAST for JSON, here we are + # sending the string embedded in the SQL without using + # a parameter. + data=bindparam(None, json_data, literal_execute=True), + nulldata=bindparam(None, json_data, literal_execute=True), + ), + ) + else: + connection.execute( + data_table.insert(), + { + "name": "row1", + "data": data_element, + "nulldata": data_element, + }, + ) + + p_s = None + + if datatype: + if datatype == "numeric": + a, b = str(value).split(".") + s = len(b) + p = len(a) + s + + if isinstance(value, decimal.Decimal): + compare_value = value + else: + compare_value = decimal.Decimal(str(value)) + + p_s = (p, s) + else: + compare_value = value + else: + compare_value = value + + return datatype, compare_value, p_s + + @_index_fixtures(False) + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_index_typed_access(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": value} + + with config.db.begin() as conn: + + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data["key1"] + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + roundtrip = conn.scalar(select(expr)) + eq_(roundtrip, compare_value) + if util.py3k: # skip py2k to avoid comparing unicode to str etc. + is_(type(roundtrip), type(compare_value)) + + @_index_fixtures(True) + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_index_typed_comparison(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": value} + + with config.db.begin() as conn: + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data["key1"] + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + row = conn.execute( + select(expr).where(expr == compare_value) + ).first() + + # make sure we get a row even if value is None + eq_(row, (compare_value,)) + + @_index_fixtures(True) + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_path_typed_comparison(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": {"subkey1": value}} + with config.db.begin() as conn: + + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data[("key1", "subkey1")] + + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + row = conn.execute( + select(expr).where(expr == compare_value) + ).first() + + # make sure we get a row even if value is None + eq_(row, (compare_value,)) + + @testing.combinations( + (True,), + (False,), + (None,), + (15,), + (0,), + (-1,), + (-1.0,), + (15.052,), + ("a string",), + (util.u("réve illé"),), + (util.u("réve🐍 illé"),), + ) + def test_single_element_round_trip(self, element): + data_table = self.tables.data_table + data_element = element + with config.db.begin() as conn: + conn.execute( + data_table.insert(), + { + "name": "row1", + "data": data_element, + "nulldata": data_element, + }, + ) + + row = conn.execute( + select(data_table.c.data, data_table.c.nulldata) + ).first() + + eq_(row, (data_element, data_element)) + + def test_round_trip_custom_json(self): + data_table = self.tables.data_table + data_element = {"key1": "data1"} + + js = mock.Mock(side_effect=json.dumps) + jd = mock.Mock(side_effect=json.loads) + engine = engines.testing_engine( + options=dict(json_serializer=js, json_deserializer=jd) + ) + + # support sqlite :memory: database... + data_table.create(engine, checkfirst=True) + with engine.begin() as conn: + conn.execute( + data_table.insert(), {"name": "row1", "data": data_element} + ) + row = conn.execute(select(data_table.c.data)).first() + + eq_(row, (data_element,)) + eq_(js.mock_calls, [mock.call(data_element)]) + eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) + + @testing.combinations( + ("parameters",), + ("multiparameters",), + ("values",), + ("omit",), + argnames="insert_type", + ) + def test_round_trip_none_as_sql_null(self, connection, insert_type): + col = self.tables.data_table.c["nulldata"] + + conn = connection + + if insert_type == "parameters": + stmt, params = self.tables.data_table.insert(), { + "name": "r1", + "nulldata": None, + "data": None, + } + elif insert_type == "multiparameters": + stmt, params = self.tables.data_table.insert(), [ + {"name": "r1", "nulldata": None, "data": None} + ] + elif insert_type == "values": + stmt, params = ( + self.tables.data_table.insert().values( + name="r1", + nulldata=None, + data=None, + ), + {}, + ) + elif insert_type == "omit": + stmt, params = ( + self.tables.data_table.insert(), + {"name": "r1", "data": None}, + ) + + else: + assert False + + conn.execute(stmt, params) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where(col.is_(null())) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + def test_round_trip_json_null_as_json_null(self, connection): + col = self.tables.data_table.c["data"] + + conn = connection + conn.execute( + self.tables.data_table.insert(), + {"name": "r1", "data": JSON.NULL}, + ) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where( + cast(col, String) == "null" + ) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + @testing.combinations( + ("parameters",), + ("multiparameters",), + ("values",), + argnames="insert_type", + ) + def test_round_trip_none_as_json_null(self, connection, insert_type): + col = self.tables.data_table.c["data"] + + if insert_type == "parameters": + stmt, params = self.tables.data_table.insert(), { + "name": "r1", + "data": None, + } + elif insert_type == "multiparameters": + stmt, params = self.tables.data_table.insert(), [ + {"name": "r1", "data": None} + ] + elif insert_type == "values": + stmt, params = ( + self.tables.data_table.insert().values(name="r1", data=None), + {}, + ) + else: + assert False + + conn = connection + conn.execute(stmt, params) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where( + cast(col, String) == "null" + ) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + def test_unicode_round_trip(self): + # note we include Unicode supplementary characters as well + with config.db.begin() as conn: + conn.execute( + self.tables.data_table.insert(), + { + "name": "r1", + "data": { + util.u("réve🐍 illé"): util.u("réve🐍 illé"), + "data": {"k1": util.u("drôl🐍e")}, + }, + }, + ) + + eq_( + conn.scalar(select(self.tables.data_table.c.data)), + { + util.u("réve🐍 illé"): util.u("réve🐍 illé"), + "data": {"k1": util.u("drôl🐍e")}, + }, + ) + + def test_eval_none_flag_orm(self, connection): + + Base = declarative_base() + + class Data(Base): + __table__ = self.tables.data_table + + with Session(connection) as s: + d1 = Data(name="d1", data=None, nulldata=None) + s.add(d1) + s.commit() + + s.bulk_insert_mappings( + Data, [{"name": "d2", "data": None, "nulldata": None}] + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), + ) + + +class JSONLegacyStringCastIndexTest( + _LiteralRoundTripFixture, fixtures.TablesTest +): + """test JSON index access with "cast to string", which we have documented + for a long time as how to compare JSON values, but is ultimately not + reliable in all cases. The "as_XYZ()" comparators should be used + instead. + + """ + + __requires__ = ("json_type", "legacy_unconditional_json_extract") + __backend__ = True + + datatype = JSON + + data1 = {"key1": "value1", "key2": "value2"} + + data2 = { + "Key 'One'": "value1", + "key two": "value2", + "key three": "value ' three '", + } + + data3 = { + "key1": [1, 2, 3], + "key2": ["one", "two", "three"], + "key3": [{"four": "five"}, {"six": "seven"}], + } + + data4 = ["one", "two", "three"] + + data5 = { + "nested": { + "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}], + "elem2": {"elem3": {"elem4": "elem5"}}, + } + } + + data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}} + + @classmethod + def define_tables(cls, metadata): + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.datatype), + Column("nulldata", cls.datatype(none_as_null=True)), + ) + + def _criteria_fixture(self): + with config.db.begin() as conn: + conn.execute( + self.tables.data_table.insert(), + [ + {"name": "r1", "data": self.data1}, + {"name": "r2", "data": self.data2}, + {"name": "r3", "data": self.data3}, + {"name": "r4", "data": self.data4}, + {"name": "r5", "data": self.data5}, + {"name": "r6", "data": self.data6}, + ], + ) + + def _test_index_criteria(self, crit, expected, test_literal=True): + self._criteria_fixture() + with config.db.connect() as conn: + stmt = select(self.tables.data_table.c.name).where(crit) + + eq_(conn.scalar(stmt), expected) + + if test_literal: + literal_sql = str( + stmt.compile( + config.db, compile_kwargs={"literal_binds": True} + ) + ) + + eq_(conn.exec_driver_sql(literal_sql).scalar(), expected) + + def test_string_cast_crit_spaces_in_key(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + # limit the rows here to avoid PG error + # "cannot extract field from a non-object", which is + # fixed in 9.4 but may exist in 9.3 + self._test_index_criteria( + and_( + name.in_(["r1", "r2", "r3"]), + cast(col["key two"], String) == '"value2"', + ), + "r2", + ) + + @config.requirements.json_array_indexes + def test_string_cast_crit_simple_int(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + # limit the rows here to avoid PG error + # "cannot extract array element from a non-array", which is + # fixed in 9.4 but may exist in 9.3 + self._test_index_criteria( + and_( + name == "r4", + cast(col[1], String) == '"two"', + ), + "r4", + ) + + def test_string_cast_crit_mixed_path(self): + col = self.tables.data_table.c["data"] + self._test_index_criteria( + cast(col[("key3", 1, "six")], String) == '"seven"', + "r3", + ) + + def test_string_cast_crit_string_path(self): + col = self.tables.data_table.c["data"] + self._test_index_criteria( + cast(col[("nested", "elem2", "elem3", "elem4")], String) + == '"elem5"', + "r5", + ) + + def test_string_cast_crit_against_string_basic(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + self._test_index_criteria( + and_( + name == "r6", + cast(col["b"], String) == '"some value"', + ), + "r6", + ) + + +__all__ = ( + "BinaryTest", + "UnicodeVarcharTest", + "UnicodeTextTest", + "JSONTest", + "JSONLegacyStringCastIndexTest", + "DateTest", + "DateTimeTest", + "DateTimeTZTest", + "TextTest", + "NumericTest", + "IntegerTest", + "CastTypeDecoratorTest", + "DateTimeHistoricTest", + "DateTimeCoercedToDateTimeTest", + "TimeMicrosecondsTest", + "TimestampMicrosecondsTest", + "TimeTest", + "TimeTZTest", + "DateTimeMicrosecondsTest", + "DateHistoricTest", + "StringTest", + "BooleanTest", +) diff --git a/lib/sqlalchemy/testing/suite/test_unicode_ddl.py b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py new file mode 100644 index 0000000..a4ae334 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py @@ -0,0 +1,206 @@ +# coding: utf-8 +"""verrrrry basic unicode column name testing""" + +from sqlalchemy import desc +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import testing +from sqlalchemy import util +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table +from sqlalchemy.util import u +from sqlalchemy.util import ue + + +class UnicodeSchemaTest(fixtures.TablesTest): + __requires__ = ("unicode_ddl",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + global t1, t2, t3 + + t1 = Table( + u("unitable1"), + metadata, + Column(u("méil"), Integer, primary_key=True), + Column(ue("\u6e2c\u8a66"), Integer), + test_needs_fk=True, + ) + t2 = Table( + u("Unitéble2"), + metadata, + Column(u("méil"), Integer, primary_key=True, key="a"), + Column( + ue("\u6e2c\u8a66"), + Integer, + ForeignKey(u("unitable1.méil")), + key="b", + ), + test_needs_fk=True, + ) + + # Few DBs support Unicode foreign keys + if testing.against("sqlite"): + t3 = Table( + ue("\u6e2c\u8a66"), + metadata, + Column( + ue("\u6e2c\u8a66_id"), + Integer, + primary_key=True, + autoincrement=False, + ), + Column( + ue("unitable1_\u6e2c\u8a66"), + Integer, + ForeignKey(ue("unitable1.\u6e2c\u8a66")), + ), + Column( + u("Unitéble2_b"), Integer, ForeignKey(u("Unitéble2.b")) + ), + Column( + ue("\u6e2c\u8a66_self"), + Integer, + ForeignKey(ue("\u6e2c\u8a66.\u6e2c\u8a66_id")), + ), + test_needs_fk=True, + ) + else: + t3 = Table( + ue("\u6e2c\u8a66"), + metadata, + Column( + ue("\u6e2c\u8a66_id"), + Integer, + primary_key=True, + autoincrement=False, + ), + Column(ue("unitable1_\u6e2c\u8a66"), Integer), + Column(u("Unitéble2_b"), Integer), + Column(ue("\u6e2c\u8a66_self"), Integer), + test_needs_fk=True, + ) + + def test_insert(self, connection): + connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5}) + connection.execute(t2.insert(), {u("a"): 1, u("b"): 1}) + connection.execute( + t3.insert(), + { + ue("\u6e2c\u8a66_id"): 1, + ue("unitable1_\u6e2c\u8a66"): 5, + u("Unitéble2_b"): 1, + ue("\u6e2c\u8a66_self"): 1, + }, + ) + + eq_(connection.execute(t1.select()).fetchall(), [(1, 5)]) + eq_(connection.execute(t2.select()).fetchall(), [(1, 1)]) + eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)]) + + def test_col_targeting(self, connection): + connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5}) + connection.execute(t2.insert(), {u("a"): 1, u("b"): 1}) + connection.execute( + t3.insert(), + { + ue("\u6e2c\u8a66_id"): 1, + ue("unitable1_\u6e2c\u8a66"): 5, + u("Unitéble2_b"): 1, + ue("\u6e2c\u8a66_self"): 1, + }, + ) + + row = connection.execute(t1.select()).first() + eq_(row._mapping[t1.c[u("méil")]], 1) + eq_(row._mapping[t1.c[ue("\u6e2c\u8a66")]], 5) + + row = connection.execute(t2.select()).first() + eq_(row._mapping[t2.c[u("a")]], 1) + eq_(row._mapping[t2.c[u("b")]], 1) + + row = connection.execute(t3.select()).first() + eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_id")]], 1) + eq_(row._mapping[t3.c[ue("unitable1_\u6e2c\u8a66")]], 5) + eq_(row._mapping[t3.c[u("Unitéble2_b")]], 1) + eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_self")]], 1) + + def test_reflect(self, connection): + connection.execute(t1.insert(), {u("méil"): 2, ue("\u6e2c\u8a66"): 7}) + connection.execute(t2.insert(), {u("a"): 2, u("b"): 2}) + connection.execute( + t3.insert(), + { + ue("\u6e2c\u8a66_id"): 2, + ue("unitable1_\u6e2c\u8a66"): 7, + u("Unitéble2_b"): 2, + ue("\u6e2c\u8a66_self"): 2, + }, + ) + + meta = MetaData() + tt1 = Table(t1.name, meta, autoload_with=connection) + tt2 = Table(t2.name, meta, autoload_with=connection) + tt3 = Table(t3.name, meta, autoload_with=connection) + + connection.execute(tt1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5}) + connection.execute(tt2.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 1}) + connection.execute( + tt3.insert(), + { + ue("\u6e2c\u8a66_id"): 1, + ue("unitable1_\u6e2c\u8a66"): 5, + u("Unitéble2_b"): 1, + ue("\u6e2c\u8a66_self"): 1, + }, + ) + + eq_( + connection.execute( + tt1.select().order_by(desc(u("méil"))) + ).fetchall(), + [(2, 7), (1, 5)], + ) + eq_( + connection.execute( + tt2.select().order_by(desc(u("méil"))) + ).fetchall(), + [(2, 2), (1, 1)], + ) + eq_( + connection.execute( + tt3.select().order_by(desc(ue("\u6e2c\u8a66_id"))) + ).fetchall(), + [(2, 7, 2, 2), (1, 5, 1, 1)], + ) + + def test_repr(self): + meta = MetaData() + t = Table( + ue("\u6e2c\u8a66"), meta, Column(ue("\u6e2c\u8a66_id"), Integer) + ) + + if util.py2k: + eq_( + repr(t), + ( + "Table('\\u6e2c\\u8a66', MetaData(), " + "Column('\\u6e2c\\u8a66_id', Integer(), " + "table=<\u6e2c\u8a66>), " + "schema=None)" + ), + ) + else: + eq_( + repr(t), + ( + "Table('測試', MetaData(), " + "Column('測試_id', Integer(), " + "table=<測試>), " + "schema=None)" + ), + ) diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py new file mode 100644 index 0000000..f04a9d5 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_update_delete.py @@ -0,0 +1,60 @@ +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import Integer +from ... import String + + +class SimpleUpdateDeleteTest(fixtures.TablesTest): + run_deletes = "each" + __requires__ = ("sane_rowcount",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.plain_pk.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + def test_update(self, connection): + t = self.tables.plain_pk + r = connection.execute( + t.update().where(t.c.id == 2), dict(data="d2_new") + ) + assert not r.is_insert + assert not r.returns_rows + assert r.rowcount == 1 + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + assert not r.returns_rows + assert r.rowcount == 1 + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) + + +__all__ = ("SimpleUpdateDeleteTest",) diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py new file mode 100644 index 0000000..be89bc6 --- /dev/null +++ b/lib/sqlalchemy/testing/util.py @@ -0,0 +1,458 @@ +# testing/util.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import decimal +import gc +import random +import sys +import types + +from . import config +from . import mock +from .. import inspect +from ..engine import Connection +from ..schema import Column +from ..schema import DropConstraint +from ..schema import DropTable +from ..schema import ForeignKeyConstraint +from ..schema import MetaData +from ..schema import Table +from ..sql import schema +from ..sql.sqltypes import Integer +from ..util import decorator +from ..util import defaultdict +from ..util import has_refcount_gc +from ..util import inspect_getfullargspec +from ..util import py2k + + +if not has_refcount_gc: + + def non_refcount_gc_collect(*args): + gc.collect() + gc.collect() + + gc_collect = lazy_gc = non_refcount_gc_collect +else: + # assume CPython - straight gc.collect, lazy_gc() is a pass + gc_collect = gc.collect + + def lazy_gc(): + pass + + +def picklers(): + picklers = set() + if py2k: + try: + import cPickle + + picklers.add(cPickle) + except ImportError: + pass + + import pickle + + picklers.add(pickle) + + # yes, this thing needs this much testing + for pickle_ in picklers: + for protocol in range(-2, pickle.HIGHEST_PROTOCOL): + yield pickle_.loads, lambda d: pickle_.dumps(d, protocol) + + +if py2k: + + def random_choices(population, k=1): + pop = list(population) + # lame but works :) + random.shuffle(pop) + return pop[0:k] + + +else: + + def random_choices(population, k=1): + return random.choices(population, k=k) + + +def round_decimal(value, prec): + if isinstance(value, float): + return round(value, prec) + + # can also use shift() here but that is 2.6 only + return (value * decimal.Decimal("1" + "0" * prec)).to_integral( + decimal.ROUND_FLOOR + ) / pow(10, prec) + + +class RandomSet(set): + def __iter__(self): + l = list(set.__iter__(self)) + random.shuffle(l) + return iter(l) + + def pop(self): + index = random.randint(0, len(self) - 1) + item = list(set.__iter__(self))[index] + self.remove(item) + return item + + def union(self, other): + return RandomSet(set.union(self, other)) + + def difference(self, other): + return RandomSet(set.difference(self, other)) + + def intersection(self, other): + return RandomSet(set.intersection(self, other)) + + def copy(self): + return RandomSet(self) + + +def conforms_partial_ordering(tuples, sorted_elements): + """True if the given sorting conforms to the given partial ordering.""" + + deps = defaultdict(set) + for parent, child in tuples: + deps[parent].add(child) + for i, node in enumerate(sorted_elements): + for n in sorted_elements[i:]: + if node in deps[n]: + return False + else: + return True + + +def all_partial_orderings(tuples, elements): + edges = defaultdict(set) + for parent, child in tuples: + edges[child].add(parent) + + def _all_orderings(elements): + + if len(elements) == 1: + yield list(elements) + else: + for elem in elements: + subset = set(elements).difference([elem]) + if not subset.intersection(edges[elem]): + for sub_ordering in _all_orderings(subset): + yield [elem] + sub_ordering + + return iter(_all_orderings(elements)) + + +def function_named(fn, name): + """Return a function with a given __name__. + + Will assign to __name__ and return the original function if possible on + the Python implementation, otherwise a new function will be constructed. + + This function should be phased out as much as possible + in favor of @decorator. Tests that "generate" many named tests + should be modernized. + + """ + try: + fn.__name__ = name + except TypeError: + fn = types.FunctionType( + fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__ + ) + return fn + + +def run_as_contextmanager(ctx, fn, *arg, **kw): + """Run the given function under the given contextmanager, + simulating the behavior of 'with' to support older + Python versions. + + This is not necessary anymore as we have placed 2.6 + as minimum Python version, however some tests are still using + this structure. + + """ + + obj = ctx.__enter__() + try: + result = fn(obj, *arg, **kw) + ctx.__exit__(None, None, None) + return result + except: + exc_info = sys.exc_info() + raise_ = ctx.__exit__(*exc_info) + if not raise_: + raise + else: + return raise_ + + +def rowset(results): + """Converts the results of sql execution into a plain set of column tuples. + + Useful for asserting the results of an unordered query. + """ + + return {tuple(row) for row in results} + + +def fail(msg): + assert False, msg + + +@decorator +def provide_metadata(fn, *args, **kw): + """Provide bound MetaData for a single test, dropping afterwards. + + Legacy; use the "metadata" pytest fixture. + + """ + + from . import fixtures + + metadata = schema.MetaData() + self = args[0] + prev_meta = getattr(self, "metadata", None) + self.metadata = metadata + try: + return fn(*args, **kw) + finally: + # close out some things that get in the way of dropping tables. + # when using the "metadata" fixture, there is a set ordering + # of things that makes sure things are cleaned up in order, however + # the simple "decorator" nature of this legacy function means + # we have to hardcode some of that cleanup ahead of time. + + # close ORM sessions + fixtures._close_all_sessions() + + # integrate with the "connection" fixture as there are many + # tests where it is used along with provide_metadata + if fixtures._connection_fixture_connection: + # TODO: this warning can be used to find all the places + # this is used with connection fixture + # warn("mixing legacy provide metadata with connection fixture") + drop_all_tables_from_metadata( + metadata, fixtures._connection_fixture_connection + ) + # as the provide_metadata fixture is often used with "testing.db", + # when we do the drop we have to commit the transaction so that + # the DB is actually updated as the CREATE would have been + # committed + fixtures._connection_fixture_connection.get_transaction().commit() + else: + drop_all_tables_from_metadata(metadata, config.db) + self.metadata = prev_meta + + +def flag_combinations(*combinations): + """A facade around @testing.combinations() oriented towards boolean + keyword-based arguments. + + Basically generates a nice looking identifier based on the keywords + and also sets up the argument names. + + E.g.:: + + @testing.flag_combinations( + dict(lazy=False, passive=False), + dict(lazy=True, passive=False), + dict(lazy=False, passive=True), + dict(lazy=False, passive=True, raiseload=True), + ) + + + would result in:: + + @testing.combinations( + ('', False, False, False), + ('lazy', True, False, False), + ('lazy_passive', True, True, False), + ('lazy_passive', True, True, True), + id_='iaaa', + argnames='lazy,passive,raiseload' + ) + + """ + + keys = set() + + for d in combinations: + keys.update(d) + + keys = sorted(keys) + + return config.combinations( + *[ + ("_".join(k for k in keys if d.get(k, False)),) + + tuple(d.get(k, False) for k in keys) + for d in combinations + ], + id_="i" + ("a" * len(keys)), + argnames=",".join(keys) + ) + + +def lambda_combinations(lambda_arg_sets, **kw): + args = inspect_getfullargspec(lambda_arg_sets) + + arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]]) + + def create_fixture(pos): + def fixture(**kw): + return lambda_arg_sets(**kw)[pos] + + fixture.__name__ = "fixture_%3.3d" % pos + return fixture + + return config.combinations( + *[(create_fixture(i),) for i in range(len(arg_sets))], **kw + ) + + +def resolve_lambda(__fn, **kw): + """Given a no-arg lambda and a namespace, return a new lambda that + has all the values filled in. + + This is used so that we can have module-level fixtures that + refer to instance-level variables using lambdas. + + """ + + pos_args = inspect_getfullargspec(__fn)[0] + pass_pos_args = {arg: kw.pop(arg) for arg in pos_args} + glb = dict(__fn.__globals__) + glb.update(kw) + new_fn = types.FunctionType(__fn.__code__, glb) + return new_fn(**pass_pos_args) + + +def metadata_fixture(ddl="function"): + """Provide MetaData for a pytest fixture.""" + + def decorate(fn): + def run_ddl(self): + + metadata = self.metadata = schema.MetaData() + try: + result = fn(self, metadata) + metadata.create_all(config.db) + # TODO: + # somehow get a per-function dml erase fixture here + yield result + finally: + metadata.drop_all(config.db) + + return config.fixture(scope=ddl)(run_ddl) + + return decorate + + +def force_drop_names(*names): + """Force the given table names to be dropped after test complete, + isolating for foreign key cycles + + """ + + @decorator + def go(fn, *args, **kw): + + try: + return fn(*args, **kw) + finally: + drop_all_tables(config.db, inspect(config.db), include_names=names) + + return go + + +class adict(dict): + """Dict keys available as attributes. Shadows.""" + + def __getattribute__(self, key): + try: + return self[key] + except KeyError: + return dict.__getattribute__(self, key) + + def __call__(self, *keys): + return tuple([self[key] for key in keys]) + + get_all = __call__ + + +def drop_all_tables_from_metadata(metadata, engine_or_connection): + from . import engines + + def go(connection): + engines.testing_reaper.prepare_for_drop_tables(connection) + + if not connection.dialect.supports_alter: + from . import assertions + + with assertions.expect_warnings( + "Can't sort tables", assert_=False + ): + metadata.drop_all(connection) + else: + metadata.drop_all(connection) + + if not isinstance(engine_or_connection, Connection): + with engine_or_connection.begin() as connection: + go(connection) + else: + go(engine_or_connection) + + +def drop_all_tables(engine, inspector, schema=None, include_names=None): + + if include_names is not None: + include_names = set(include_names) + + with engine.begin() as conn: + for tname, fkcs in reversed( + inspector.get_sorted_table_and_fkc_names(schema=schema) + ): + if tname: + if include_names is not None and tname not in include_names: + continue + conn.execute( + DropTable(Table(tname, MetaData(), schema=schema)) + ) + elif fkcs: + if not engine.dialect.supports_alter: + continue + for tname, fkc in fkcs: + if ( + include_names is not None + and tname not in include_names + ): + continue + tb = Table( + tname, + MetaData(), + Column("x", Integer), + Column("y", Integer), + schema=schema, + ) + conn.execute( + DropConstraint( + ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc) + ) + ) + + +def teardown_events(event_cls): + @decorator + def decorate(fn, *arg, **kw): + try: + return fn(*arg, **kw) + finally: + event_cls._clear() + + return decorate diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py new file mode 100644 index 0000000..3e78387 --- /dev/null +++ b/lib/sqlalchemy/testing/warnings.py @@ -0,0 +1,82 @@ +# testing/warnings.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import absolute_import + +import warnings + +from . import assertions +from .. import exc as sa_exc +from ..util.langhelpers import _warnings_warn + + +class SATestSuiteWarning(Warning): + """warning for a condition detected during tests that is non-fatal + + Currently outside of SAWarning so that we can work around tools like + Alembic doing the wrong thing with warnings. + + """ + + +def warn_test_suite(message): + _warnings_warn(message, category=SATestSuiteWarning) + + +def setup_filters(): + """Set global warning behavior for the test suite.""" + + # TODO: at this point we can use the normal pytest warnings plugin, + # if we decide the test suite can be linked to pytest only + + origin = r"^(?:test|sqlalchemy)\..*" + + warnings.filterwarnings( + "ignore", category=sa_exc.SAPendingDeprecationWarning + ) + warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning) + warnings.filterwarnings("error", category=sa_exc.SAWarning) + + warnings.filterwarnings("always", category=SATestSuiteWarning) + + warnings.filterwarnings( + "error", category=DeprecationWarning, module=origin + ) + + # ignore things that are deprecated *as of* 2.0 :) + warnings.filterwarnings( + "ignore", + category=sa_exc.SADeprecationWarning, + message=r".*\(deprecated since: 2.0\)$", + ) + warnings.filterwarnings( + "ignore", + category=sa_exc.SADeprecationWarning, + message=r"^The (Sybase|firebird) dialect is deprecated and will be", + ) + + try: + import pytest + except ImportError: + pass + else: + warnings.filterwarnings( + "once", category=pytest.PytestDeprecationWarning, module=origin + ) + + +def assert_warnings(fn, warning_msgs, regex=False): + """Assert that each of the given warnings are emitted by fn. + + Deprecated. Please use assertions.expect_warnings(). + + """ + + with assertions._expect_warnings( + sa_exc.SAWarning, warning_msgs, regex=regex + ): + return fn() |