summaryrefslogtreecommitdiffstats
path: root/lib/sqlalchemy/testing
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing')
-rw-r--r--lib/sqlalchemy/testing/__init__.py86
-rw-r--r--lib/sqlalchemy/testing/assertions.py845
-rw-r--r--lib/sqlalchemy/testing/assertsql.py457
-rw-r--r--lib/sqlalchemy/testing/asyncio.py128
-rw-r--r--lib/sqlalchemy/testing/config.py209
-rw-r--r--lib/sqlalchemy/testing/engines.py465
-rw-r--r--lib/sqlalchemy/testing/entities.py111
-rw-r--r--lib/sqlalchemy/testing/exclusions.py465
-rw-r--r--lib/sqlalchemy/testing/fixtures.py870
-rw-r--r--lib/sqlalchemy/testing/mock.py32
-rw-r--r--lib/sqlalchemy/testing/pickleable.py151
-rw-r--r--lib/sqlalchemy/testing/plugin/__init__.py0
-rw-r--r--lib/sqlalchemy/testing/plugin/bootstrap.py54
-rw-r--r--lib/sqlalchemy/testing/plugin/plugin_base.py789
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py820
-rw-r--r--lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py112
-rw-r--r--lib/sqlalchemy/testing/profiling.py335
-rw-r--r--lib/sqlalchemy/testing/provision.py416
-rw-r--r--lib/sqlalchemy/testing/requirements.py1518
-rw-r--r--lib/sqlalchemy/testing/schema.py218
-rw-r--r--lib/sqlalchemy/testing/suite/__init__.py13
-rw-r--r--lib/sqlalchemy/testing/suite/test_cte.py204
-rw-r--r--lib/sqlalchemy/testing/suite/test_ddl.py381
-rw-r--r--lib/sqlalchemy/testing/suite/test_deprecations.py145
-rw-r--r--lib/sqlalchemy/testing/suite/test_dialect.py361
-rw-r--r--lib/sqlalchemy/testing/suite/test_insert.py367
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py1738
-rw-r--r--lib/sqlalchemy/testing/suite/test_results.py426
-rw-r--r--lib/sqlalchemy/testing/suite/test_rowcount.py165
-rw-r--r--lib/sqlalchemy/testing/suite/test_select.py1783
-rw-r--r--lib/sqlalchemy/testing/suite/test_sequence.py282
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py1508
-rw-r--r--lib/sqlalchemy/testing/suite/test_unicode_ddl.py206
-rw-r--r--lib/sqlalchemy/testing/suite/test_update_delete.py60
-rw-r--r--lib/sqlalchemy/testing/util.py458
-rw-r--r--lib/sqlalchemy/testing/warnings.py82
36 files changed, 16260 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py
new file mode 100644
index 0000000..80d344f
--- /dev/null
+++ b/lib/sqlalchemy/testing/__init__.py
@@ -0,0 +1,86 @@
+# testing/__init__.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+from . import config
+from . import mock
+from .assertions import assert_raises
+from .assertions import assert_raises_context_ok
+from .assertions import assert_raises_message
+from .assertions import assert_raises_message_context_ok
+from .assertions import assert_warns
+from .assertions import assert_warns_message
+from .assertions import AssertsCompiledSQL
+from .assertions import AssertsExecutionResults
+from .assertions import ComparesTables
+from .assertions import emits_warning
+from .assertions import emits_warning_on
+from .assertions import eq_
+from .assertions import eq_ignore_whitespace
+from .assertions import eq_regex
+from .assertions import expect_deprecated
+from .assertions import expect_deprecated_20
+from .assertions import expect_raises
+from .assertions import expect_raises_message
+from .assertions import expect_warnings
+from .assertions import in_
+from .assertions import is_
+from .assertions import is_false
+from .assertions import is_instance_of
+from .assertions import is_none
+from .assertions import is_not
+from .assertions import is_not_
+from .assertions import is_not_none
+from .assertions import is_true
+from .assertions import le_
+from .assertions import ne_
+from .assertions import not_in
+from .assertions import not_in_
+from .assertions import startswith_
+from .assertions import uses_deprecated
+from .config import async_test
+from .config import combinations
+from .config import combinations_list
+from .config import db
+from .config import fixture
+from .config import requirements as requires
+from .exclusions import _is_excluded
+from .exclusions import _server_version
+from .exclusions import against as _against
+from .exclusions import db_spec
+from .exclusions import exclude
+from .exclusions import fails
+from .exclusions import fails_if
+from .exclusions import fails_on
+from .exclusions import fails_on_everything_except
+from .exclusions import future
+from .exclusions import only_if
+from .exclusions import only_on
+from .exclusions import skip
+from .exclusions import skip_if
+from .schema import eq_clause_element
+from .schema import eq_type_affinity
+from .util import adict
+from .util import fail
+from .util import flag_combinations
+from .util import force_drop_names
+from .util import lambda_combinations
+from .util import metadata_fixture
+from .util import provide_metadata
+from .util import resolve_lambda
+from .util import rowset
+from .util import run_as_contextmanager
+from .util import teardown_events
+from .warnings import assert_warnings
+from .warnings import warn_test_suite
+
+
+def against(*queries):
+ return _against(config._current, *queries)
+
+
+crashes = skip
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
new file mode 100644
index 0000000..9a3c06b
--- /dev/null
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -0,0 +1,845 @@
+# testing/assertions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+import contextlib
+import re
+import sys
+import warnings
+
+from . import assertsql
+from . import config
+from . import engines
+from . import mock
+from .exclusions import db_spec
+from .util import fail
+from .. import exc as sa_exc
+from .. import schema
+from .. import sql
+from .. import types as sqltypes
+from .. import util
+from ..engine import default
+from ..engine import url
+from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util import compat
+from ..util import decorator
+
+
+def expect_warnings(*messages, **kw):
+ """Context manager which expects one or more warnings.
+
+ With no arguments, squelches all SAWarning and RemovedIn20Warning emitted via
+ sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
+ pass string expressions that will match selected warnings via regex;
+ all non-matching warnings are sent through.
+
+ The expect version **asserts** that the warnings were in fact seen.
+
+ Note that the test suite sets SAWarning warnings to raise exceptions.
+
+ """ # noqa
+ return _expect_warnings(
+ (sa_exc.RemovedIn20Warning, sa_exc.SAWarning), messages, **kw
+ )
+
+
+@contextlib.contextmanager
+def expect_warnings_on(db, *messages, **kw):
+ """Context manager which expects one or more warnings on specific
+ dialects.
+
+ The expect version **asserts** that the warnings were in fact seen.
+
+ """
+ spec = db_spec(db)
+
+ if isinstance(db, util.string_types) and not spec(config._current):
+ yield
+ else:
+ with expect_warnings(*messages, **kw):
+ yield
+
+
+def emits_warning(*messages):
+ """Decorator form of expect_warnings().
+
+ Note that emits_warning does **not** assert that the warnings
+ were in fact seen.
+
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ with expect_warnings(assert_=False, *messages):
+ return fn(*args, **kw)
+
+ return decorate
+
+
+def expect_deprecated(*messages, **kw):
+ return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
+
+
+def expect_deprecated_20(*messages, **kw):
+ return _expect_warnings(sa_exc.Base20DeprecationWarning, messages, **kw)
+
+
+def emits_warning_on(db, *messages):
+ """Mark a test as emitting a warning on a specific dialect.
+
+ With no arguments, squelches all SAWarning failures. Or pass one or more
+ strings; these will be matched to the root of the warning description by
+ warnings.filterwarnings().
+
+ Note that emits_warning_on does **not** assert that the warnings
+ were in fact seen.
+
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ with expect_warnings_on(db, assert_=False, *messages):
+ return fn(*args, **kw)
+
+ return decorate
+
+
+def uses_deprecated(*messages):
+ """Mark a test as immune from fatal deprecation warnings.
+
+ With no arguments, squelches all SADeprecationWarning failures.
+ Or pass one or more strings; these will be matched to the root
+ of the warning description by warnings.filterwarnings().
+
+ As a special case, you may pass a function name prefixed with //
+ and it will be re-written as needed to match the standard warning
+ verbiage emitted by the sqlalchemy.util.deprecated decorator.
+
+ Note that uses_deprecated does **not** assert that the warnings
+ were in fact seen.
+
+ """
+
+ @decorator
+ def decorate(fn, *args, **kw):
+ with expect_deprecated(*messages, assert_=False):
+ return fn(*args, **kw)
+
+ return decorate
+
+
+_FILTERS = None
+_SEEN = None
+_EXC_CLS = None
+
+
+@contextlib.contextmanager
+def _expect_warnings(
+ exc_cls,
+ messages,
+ regex=True,
+ search_msg=False,
+ assert_=True,
+ py2konly=False,
+ raise_on_any_unexpected=False,
+ squelch_other_warnings=False,
+):
+
+ global _FILTERS, _SEEN, _EXC_CLS
+
+ if regex or search_msg:
+ filters = [re.compile(msg, re.I | re.S) for msg in messages]
+ else:
+ filters = list(messages)
+
+ if _FILTERS is not None:
+ # nested call; update _FILTERS and _SEEN, return. outer
+ # block will assert our messages
+ assert _SEEN is not None
+ assert _EXC_CLS is not None
+ _FILTERS.extend(filters)
+ _SEEN.update(filters)
+ _EXC_CLS += (exc_cls,)
+ yield
+ else:
+ seen = _SEEN = set(filters)
+ _FILTERS = filters
+ _EXC_CLS = (exc_cls,)
+
+ if raise_on_any_unexpected:
+
+ def real_warn(msg, *arg, **kw):
+ raise AssertionError("Got unexpected warning: %r" % msg)
+
+ else:
+ real_warn = warnings.warn
+
+ def our_warn(msg, *arg, **kw):
+
+ if isinstance(msg, _EXC_CLS):
+ exception = type(msg)
+ msg = str(msg)
+ elif arg:
+ exception = arg[0]
+ else:
+ exception = None
+
+ if not exception or not issubclass(exception, _EXC_CLS):
+ if not squelch_other_warnings:
+ return real_warn(msg, *arg, **kw)
+ else:
+ return
+
+ if not filters and not raise_on_any_unexpected:
+ return
+
+ for filter_ in filters:
+ if (
+ (search_msg and filter_.search(msg))
+ or (regex and filter_.match(msg))
+ or (not regex and filter_ == msg)
+ ):
+ seen.discard(filter_)
+ break
+ else:
+ if not squelch_other_warnings:
+ real_warn(msg, *arg, **kw)
+
+ with mock.patch("warnings.warn", our_warn), mock.patch(
+ "sqlalchemy.util.SQLALCHEMY_WARN_20", True
+ ), mock.patch(
+ "sqlalchemy.util.deprecations.SQLALCHEMY_WARN_20", True
+ ), mock.patch(
+ "sqlalchemy.engine.row.LegacyRow._default_key_style", 2
+ ):
+ try:
+ yield
+ finally:
+ _SEEN = _FILTERS = _EXC_CLS = None
+
+ if assert_ and (not py2konly or not compat.py3k):
+ assert not seen, "Warnings were not seen: %s" % ", ".join(
+ "%r" % (s.pattern if regex else s) for s in seen
+ )
+
+
+def global_cleanup_assertions():
+ """Check things that have to be finalized at the end of a test suite.
+
+ Hardcoded at the moment, a modular system can be built here
+ to support things like PG prepared transactions, tables all
+ dropped, etc.
+
+ """
+ _assert_no_stray_pool_connections()
+
+
+def _assert_no_stray_pool_connections():
+ engines.testing_reaper.assert_all_closed()
+
+
+def eq_regex(a, b, msg=None):
+ assert re.match(b, a), msg or "%r !~ %r" % (a, b)
+
+
+def eq_(a, b, msg=None):
+ """Assert a == b, with repr messaging on failure."""
+ assert a == b, msg or "%r != %r" % (a, b)
+
+
+def ne_(a, b, msg=None):
+ """Assert a != b, with repr messaging on failure."""
+ assert a != b, msg or "%r == %r" % (a, b)
+
+
+def le_(a, b, msg=None):
+ """Assert a <= b, with repr messaging on failure."""
+ assert a <= b, msg or "%r != %r" % (a, b)
+
+
+def is_instance_of(a, b, msg=None):
+ assert isinstance(a, b), msg or "%r is not an instance of %r" % (a, b)
+
+
+def is_none(a, msg=None):
+ is_(a, None, msg=msg)
+
+
+def is_not_none(a, msg=None):
+ is_not(a, None, msg=msg)
+
+
+def is_true(a, msg=None):
+ is_(bool(a), True, msg=msg)
+
+
+def is_false(a, msg=None):
+ is_(bool(a), False, msg=msg)
+
+
+def is_(a, b, msg=None):
+ """Assert a is b, with repr messaging on failure."""
+ assert a is b, msg or "%r is not %r" % (a, b)
+
+
+def is_not(a, b, msg=None):
+ """Assert a is not b, with repr messaging on failure."""
+ assert a is not b, msg or "%r is %r" % (a, b)
+
+
+# deprecated. See #5429
+is_not_ = is_not
+
+
+def in_(a, b, msg=None):
+ """Assert a in b, with repr messaging on failure."""
+ assert a in b, msg or "%r not in %r" % (a, b)
+
+
+def not_in(a, b, msg=None):
+    """Assert a not in b, with repr messaging on failure."""
+ assert a not in b, msg or "%r is in %r" % (a, b)
+
+
+# deprecated. See #5429
+not_in_ = not_in
+
+
+def startswith_(a, fragment, msg=None):
+ """Assert a.startswith(fragment), with repr messaging on failure."""
+ assert a.startswith(fragment), msg or "%r does not start with %r" % (
+ a,
+ fragment,
+ )
+
+
+def eq_ignore_whitespace(a, b, msg=None):
+ a = re.sub(r"^\s+?|\n", "", a)
+ a = re.sub(r" {2,}", " ", a)
+ b = re.sub(r"^\s+?|\n", "", b)
+ b = re.sub(r" {2,}", " ", b)
+
+ assert a == b, msg or "%r != %r" % (a, b)
+
+
+def _assert_proper_exception_context(exception):
+ """assert that any exception we're catching does not have a __context__
+ without a __cause__, and that __suppress_context__ is never set.
+
+ Python 3 will report nested as exceptions as "during the handling of
+ error X, error Y occurred". That's not what we want to do. we want
+ these exceptions in a cause chain.
+
+ """
+
+ if not util.py3k:
+ return
+
+ if (
+ exception.__context__ is not exception.__cause__
+ and not exception.__suppress_context__
+ ):
+ assert False, (
+ "Exception %r was correctly raised but did not set a cause, "
+ "within context %r as its cause."
+ % (exception, exception.__context__)
+ )
+
+
+def assert_raises(except_cls, callable_, *args, **kw):
+ return _assert_raises(except_cls, callable_, args, kw, check_context=True)
+
+
+def assert_raises_context_ok(except_cls, callable_, *args, **kw):
+ return _assert_raises(except_cls, callable_, args, kw)
+
+
+def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
+ return _assert_raises(
+ except_cls, callable_, args, kwargs, msg=msg, check_context=True
+ )
+
+
+def assert_warns(except_cls, callable_, *args, **kwargs):
+ """legacy adapter function for functions that were previously using
+ assert_raises with SAWarning or similar.
+
+ has some workarounds to accommodate the fact that the callable completes
+ with this approach rather than stopping at the exception raise.
+
+
+ """
+ with _expect_warnings(except_cls, [".*"], squelch_other_warnings=True):
+ return callable_(*args, **kwargs)
+
+
+def assert_warns_message(except_cls, msg, callable_, *args, **kwargs):
+ """legacy adapter function for functions that were previously using
+ assert_raises with SAWarning or similar.
+
+ has some workarounds to accommodate the fact that the callable completes
+ with this approach rather than stopping at the exception raise.
+
+ Also uses regex.search() to match the given message to the error string
+ rather than regex.match().
+
+ """
+ with _expect_warnings(
+ except_cls,
+ [msg],
+ search_msg=True,
+ regex=False,
+ squelch_other_warnings=True,
+ ):
+ return callable_(*args, **kwargs)
+
+
+def assert_raises_message_context_ok(
+ except_cls, msg, callable_, *args, **kwargs
+):
+ return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
+
+
+def _assert_raises(
+ except_cls, callable_, args, kwargs, msg=None, check_context=False
+):
+
+ with _expect_raises(except_cls, msg, check_context) as ec:
+ callable_(*args, **kwargs)
+ return ec.error
+
+
+class _ErrorContainer(object):
+ error = None
+
+
+@contextlib.contextmanager
+def _expect_raises(except_cls, msg=None, check_context=False):
+ if (
+ isinstance(except_cls, type)
+ and issubclass(except_cls, Warning)
+ or isinstance(except_cls, Warning)
+ ):
+ raise TypeError(
+ "Use expect_warnings for warnings, not "
+ "expect_raises / assert_raises"
+ )
+ ec = _ErrorContainer()
+ if check_context:
+ are_we_already_in_a_traceback = sys.exc_info()[0]
+ try:
+ yield ec
+ success = False
+ except except_cls as err:
+ ec.error = err
+ success = True
+ if msg is not None:
+ assert re.search(
+ msg, util.text_type(err), re.UNICODE
+ ), "%r !~ %s" % (msg, err)
+ if check_context and not are_we_already_in_a_traceback:
+ _assert_proper_exception_context(err)
+ print(util.text_type(err).encode("utf-8"))
+
+ # it's generally a good idea to not carry traceback objects outside
+ # of the except: block, but in this case especially we seem to have
+ # hit some bug in either python 3.10.0b2 or greenlet or both which
+ # this seems to fix:
+ # https://github.com/python-greenlet/greenlet/issues/242
+ del ec
+
+ # assert outside the block so it works for AssertionError too !
+ assert success, "Callable did not raise an exception"
+
+
+def expect_raises(except_cls, check_context=True):
+ return _expect_raises(except_cls, check_context=check_context)
+
+
+def expect_raises_message(except_cls, msg, check_context=True):
+ return _expect_raises(except_cls, msg=msg, check_context=check_context)
+
+
+class AssertsCompiledSQL(object):
+ def assert_compile(
+ self,
+ clause,
+ result,
+ params=None,
+ checkparams=None,
+ for_executemany=False,
+ check_literal_execute=None,
+ check_post_param=None,
+ dialect=None,
+ checkpositional=None,
+ check_prefetch=None,
+ use_default_dialect=False,
+ allow_dialect_select=False,
+ supports_default_values=True,
+ supports_default_metavalue=True,
+ literal_binds=False,
+ render_postcompile=False,
+ schema_translate_map=None,
+ render_schema_translate=False,
+ default_schema_name=None,
+ from_linting=False,
+ ):
+ if use_default_dialect:
+ dialect = default.DefaultDialect()
+ dialect.supports_default_values = supports_default_values
+ dialect.supports_default_metavalue = supports_default_metavalue
+ elif allow_dialect_select:
+ dialect = None
+ else:
+ if dialect is None:
+ dialect = getattr(self, "__dialect__", None)
+
+ if dialect is None:
+ dialect = config.db.dialect
+ elif dialect == "default":
+ dialect = default.DefaultDialect()
+ dialect.supports_default_values = supports_default_values
+ dialect.supports_default_metavalue = supports_default_metavalue
+ elif dialect == "default_enhanced":
+ dialect = default.StrCompileDialect()
+ elif isinstance(dialect, util.string_types):
+ dialect = url.URL.create(dialect).get_dialect()()
+
+ if default_schema_name:
+ dialect.default_schema_name = default_schema_name
+
+ kw = {}
+ compile_kwargs = {}
+
+ if schema_translate_map:
+ kw["schema_translate_map"] = schema_translate_map
+
+ if params is not None:
+ kw["column_keys"] = list(params)
+
+ if literal_binds:
+ compile_kwargs["literal_binds"] = True
+
+ if render_postcompile:
+ compile_kwargs["render_postcompile"] = True
+
+ if for_executemany:
+ kw["for_executemany"] = True
+
+ if render_schema_translate:
+ kw["render_schema_translate"] = True
+
+ if from_linting or getattr(self, "assert_from_linting", False):
+ kw["linting"] = sql.FROM_LINTING
+
+ from sqlalchemy import orm
+
+ if isinstance(clause, orm.Query):
+ stmt = clause._statement_20()
+ stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL
+ clause = stmt
+
+ if compile_kwargs:
+ kw["compile_kwargs"] = compile_kwargs
+
+ class DontAccess(object):
+ def __getattribute__(self, key):
+ raise NotImplementedError(
+ "compiler accessed .statement; use "
+ "compiler.current_executable"
+ )
+
+ class CheckCompilerAccess(object):
+ def __init__(self, test_statement):
+ self.test_statement = test_statement
+ self._annotations = {}
+ self.supports_execution = getattr(
+ test_statement, "supports_execution", False
+ )
+
+ if self.supports_execution:
+ self._execution_options = test_statement._execution_options
+
+ if hasattr(test_statement, "_returning"):
+ self._returning = test_statement._returning
+ if hasattr(test_statement, "_inline"):
+ self._inline = test_statement._inline
+ if hasattr(test_statement, "_return_defaults"):
+ self._return_defaults = test_statement._return_defaults
+
+ def _default_dialect(self):
+ return self.test_statement._default_dialect()
+
+ def compile(self, dialect, **kw):
+ return self.test_statement.compile.__func__(
+ self, dialect=dialect, **kw
+ )
+
+ def _compiler(self, dialect, **kw):
+ return self.test_statement._compiler.__func__(
+ self, dialect, **kw
+ )
+
+ def _compiler_dispatch(self, compiler, **kwargs):
+ if hasattr(compiler, "statement"):
+ with mock.patch.object(
+ compiler, "statement", DontAccess()
+ ):
+ return self.test_statement._compiler_dispatch(
+ compiler, **kwargs
+ )
+ else:
+ return self.test_statement._compiler_dispatch(
+ compiler, **kwargs
+ )
+
+ # no construct can assume it's the "top level" construct in all cases
+ # as anything can be nested. ensure constructs don't assume they
+ # are the "self.statement" element
+ c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)
+
+ if isinstance(clause, sqltypes.TypeEngine):
+ cache_key_no_warnings = clause._static_cache_key
+ if cache_key_no_warnings:
+ hash(cache_key_no_warnings)
+ else:
+ cache_key_no_warnings = clause._generate_cache_key()
+ if cache_key_no_warnings:
+ hash(cache_key_no_warnings[0])
+
+ param_str = repr(getattr(c, "params", {}))
+ if util.py3k:
+ param_str = param_str.encode("utf-8").decode("ascii", "ignore")
+ print(
+ ("\nSQL String:\n" + util.text_type(c) + param_str).encode(
+ "utf-8"
+ )
+ )
+ else:
+ print(
+ "\nSQL String:\n"
+ + util.text_type(c).encode("utf-8")
+ + param_str
+ )
+
+ cc = re.sub(r"[\n\t]", "", util.text_type(c))
+
+ eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
+
+ if checkparams is not None:
+ eq_(c.construct_params(params), checkparams)
+ if checkpositional is not None:
+ p = c.construct_params(params)
+ eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
+ if check_prefetch is not None:
+ eq_(c.prefetch, check_prefetch)
+ if check_literal_execute is not None:
+ eq_(
+ {
+ c.bind_names[b]: b.effective_value
+ for b in c.literal_execute_params
+ },
+ check_literal_execute,
+ )
+ if check_post_param is not None:
+ eq_(
+ {
+ c.bind_names[b]: b.effective_value
+ for b in c.post_compile_params
+ },
+ check_post_param,
+ )
+
+
+class ComparesTables(object):
+ def assert_tables_equal(self, table, reflected_table, strict_types=False):
+ assert len(table.c) == len(reflected_table.c)
+ for c, reflected_c in zip(table.c, reflected_table.c):
+ eq_(c.name, reflected_c.name)
+ assert reflected_c is reflected_table.c[c.name]
+ eq_(c.primary_key, reflected_c.primary_key)
+ eq_(c.nullable, reflected_c.nullable)
+
+ if strict_types:
+ msg = "Type '%s' doesn't correspond to type '%s'"
+ assert isinstance(reflected_c.type, type(c.type)), msg % (
+ reflected_c.type,
+ c.type,
+ )
+ else:
+ self.assert_types_base(reflected_c, c)
+
+ if isinstance(c.type, sqltypes.String):
+ eq_(c.type.length, reflected_c.type.length)
+
+ eq_(
+ {f.column.name for f in c.foreign_keys},
+ {f.column.name for f in reflected_c.foreign_keys},
+ )
+ if c.server_default:
+ assert isinstance(
+ reflected_c.server_default, schema.FetchedValue
+ )
+
+ assert len(table.primary_key) == len(reflected_table.primary_key)
+ for c in table.primary_key:
+ assert reflected_table.primary_key.columns[c.name] is not None
+
+ def assert_types_base(self, c1, c2):
+ assert c1.type._compare_type_affinity(
+ c2.type
+ ), "On column %r, type '%s' doesn't correspond to type '%s'" % (
+ c1.name,
+ c1.type,
+ c2.type,
+ )
+
+
+class AssertsExecutionResults(object):
+ def assert_result(self, result, class_, *objects):
+ result = list(result)
+ print(repr(result))
+ self.assert_list(result, class_, objects)
+
+ def assert_list(self, result, class_, list_):
+ self.assert_(
+ len(result) == len(list_),
+ "result list is not the same size as test list, "
+ + "for class "
+ + class_.__name__,
+ )
+ for i in range(0, len(list_)):
+ self.assert_row(class_, result[i], list_[i])
+
+ def assert_row(self, class_, rowobj, desc):
+ self.assert_(
+ rowobj.__class__ is class_, "item class is not " + repr(class_)
+ )
+ for key, value in desc.items():
+ if isinstance(value, tuple):
+ if isinstance(value[1], list):
+ self.assert_list(getattr(rowobj, key), value[0], value[1])
+ else:
+ self.assert_row(value[0], getattr(rowobj, key), value[1])
+ else:
+ self.assert_(
+ getattr(rowobj, key) == value,
+ "attribute %s value %s does not match %s"
+ % (key, getattr(rowobj, key), value),
+ )
+
+ def assert_unordered_result(self, result, cls, *expected):
+ """As assert_result, but the order of objects is not considered.
+
+ The algorithm is very expensive but not a big deal for the small
+ numbers of rows that the test suite manipulates.
+ """
+
+ class immutabledict(dict):
+ def __hash__(self):
+ return id(self)
+
+ found = util.IdentitySet(result)
+ expected = {immutabledict(e) for e in expected}
+
+ for wrong in util.itertools_filterfalse(
+ lambda o: isinstance(o, cls), found
+ ):
+ fail(
+ 'Unexpected type "%s", expected "%s"'
+ % (type(wrong).__name__, cls.__name__)
+ )
+
+ if len(found) != len(expected):
+ fail(
+ 'Unexpected object count "%s", expected "%s"'
+ % (len(found), len(expected))
+ )
+
+ NOVALUE = object()
+
+ def _compare_item(obj, spec):
+ for key, value in spec.items():
+ if isinstance(value, tuple):
+ try:
+ self.assert_unordered_result(
+ getattr(obj, key), value[0], *value[1]
+ )
+ except AssertionError:
+ return False
+ else:
+ if getattr(obj, key, NOVALUE) != value:
+ return False
+ return True
+
+ for expected_item in expected:
+ for found_item in found:
+ if _compare_item(found_item, expected_item):
+ found.remove(found_item)
+ break
+ else:
+ fail(
+ "Expected %s instance with attributes %s not found."
+ % (cls.__name__, repr(expected_item))
+ )
+ return True
+
+ def sql_execution_asserter(self, db=None):
+ if db is None:
+ from . import db as db
+
+ return assertsql.assert_engine(db)
+
+ def assert_sql_execution(self, db, callable_, *rules):
+ with self.sql_execution_asserter(db) as asserter:
+ result = callable_()
+ asserter.assert_(*rules)
+ return result
+
+ def assert_sql(self, db, callable_, rules):
+
+ newrules = []
+ for rule in rules:
+ if isinstance(rule, dict):
+ newrule = assertsql.AllOf(
+ *[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
+ )
+ else:
+ newrule = assertsql.CompiledSQL(*rule)
+ newrules.append(newrule)
+
+ return self.assert_sql_execution(db, callable_, *newrules)
+
+ def assert_sql_count(self, db, callable_, count):
+ self.assert_sql_execution(
+ db, callable_, assertsql.CountStatements(count)
+ )
+
+ def assert_multiple_sql_count(self, dbs, callable_, counts):
+ recs = [
+ (self.sql_execution_asserter(db), db, count)
+ for (db, count) in zip(dbs, counts)
+ ]
+ asserters = []
+ for ctx, db, count in recs:
+ asserters.append(ctx.__enter__())
+ try:
+ return callable_()
+ finally:
+ for asserter, (ctx, db, count) in zip(asserters, recs):
+ ctx.__exit__(None, None, None)
+ asserter.assert_(assertsql.CountStatements(count))
+
+ @contextlib.contextmanager
+ def assert_execution(self, db, *rules):
+ with self.sql_execution_asserter(db) as asserter:
+ yield
+ asserter.assert_(*rules)
+
+ def assert_statement_count(self, db, count):
+ return self.assert_execution(db, assertsql.CountStatements(count))
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
new file mode 100644
index 0000000..565b3ed
--- /dev/null
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -0,0 +1,457 @@
+# testing/assertsql.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import collections
+import contextlib
+import re
+
+from .. import event
+from .. import util
+from ..engine import url
+from ..engine.default import DefaultDialect
+from ..engine.util import _distill_cursor_params
+from ..schema import _DDLCompiles
+
+
class AssertRule(object):
    """Base class for rules applied against observed SQL executions."""

    # set True once the rule has fully matched and is finished
    is_consumed = False
    # populated with a failure description when a match fails
    errormessage = None
    # whether the observed statement should be consumed even when this
    # rule did not match it
    consume_statement = True

    def process_statement(self, execute_observed):
        """Examine one observed execution; subclasses set is_consumed
        and/or errormessage."""
        pass

    def no_more_statements(self):
        """Called when execution ends with this rule still pending."""
        assert False, (
            "All statements are complete, but pending "
            "assertion rules remain"
        )
+
+
class SQLMatchRule(AssertRule):
    """Base for rules that match against statement text."""

    pass
+
+
class CursorSQL(SQLMatchRule):
    """Match an exact cursor-level statement string (and optionally
    exact parameters) as actually sent to the DBAPI."""

    def __init__(self, statement, params=None, consume_statement=True):
        self.statement = statement
        self.params = params
        self.consume_statement = consume_statement

    def process_statement(self, execute_observed):
        # compare against the first not-yet-consumed cursor statement
        stmt = execute_observed.statements[0]
        if self.statement != stmt.statement or (
            self.params is not None and self.params != stmt.parameters
        ):
            self.errormessage = (
                "Testing for exact SQL %s parameters %s received %s %s"
                % (
                    self.statement,
                    self.params,
                    stmt.statement,
                    stmt.parameters,
                )
            )
        else:
            execute_observed.statements.pop(0)
            self.is_consumed = True
            # once all cursor statements are matched, allow the
            # enclosing execution to be consumed
            if not execute_observed.statements:
                self.consume_statement = True
+
+
class CompiledSQL(SQLMatchRule):
    """Match a statement against an expected compiled SQL string,
    re-compiled on a fixed dialect (DefaultDialect unless otherwise
    given), with optional positive-only parameter comparison."""

    def __init__(self, statement, params=None, dialect="default"):
        self.statement = statement
        self.params = params
        self.dialect = dialect

    def _compare_sql(self, execute_observed, received_statement):
        # strip newlines/tabs from the expected string before comparing
        stmt = re.sub(r"[\n\t]", "", self.statement)
        return received_statement == stmt

    def _compile_dialect(self, execute_observed):
        """Return the dialect instance used to re-compile the observed
        clause element for comparison."""
        if self.dialect == "default":
            dialect = DefaultDialect()
            # this is currently what tests are expecting
            # dialect.supports_default_values = True
            dialect.supports_default_metavalue = True
            return dialect
        else:
            # ugh
            if self.dialect == "postgresql":
                params = {"implicit_returning": True}
            else:
                params = {}
            return url.URL.create(self.dialect).get_dialect()(**params)

    def _received_statement(self, execute_observed):
        """reconstruct the statement and params in terms
        of a target dialect, which for CompiledSQL is just DefaultDialect."""

        context = execute_observed.context
        compare_dialect = self._compile_dialect(execute_observed)

        # received_statement runs a full compile(). we should not need to
        # consider extracted_parameters; if we do this indicates some state
        # is being sent from a previous cached query, which some misbehaviors
        # in the ORM can cause, see #6881
        cache_key = None  # execute_observed.context.compiled.cache_key
        extracted_parameters = (
            None  # execute_observed.context.extracted_parameters
        )

        if "schema_translate_map" in context.execution_options:
            map_ = context.execution_options["schema_translate_map"]
        else:
            map_ = None

        if isinstance(execute_observed.clauseelement, _DDLCompiles):

            compiled = execute_observed.clauseelement.compile(
                dialect=compare_dialect,
                schema_translate_map=map_,
            )
        else:
            compiled = execute_observed.clauseelement.compile(
                cache_key=cache_key,
                dialect=compare_dialect,
                column_keys=context.compiled.column_keys,
                for_executemany=context.compiled.for_executemany,
                schema_translate_map=map_,
            )
        _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
        parameters = execute_observed.parameters

        if not parameters:
            _received_parameters = [
                compiled.construct_params(
                    extracted_parameters=extracted_parameters
                )
            ]
        else:
            _received_parameters = [
                compiled.construct_params(
                    m, extracted_parameters=extracted_parameters
                )
                for m in parameters
            ]

        return _received_statement, _received_parameters

    def process_statement(self, execute_observed):
        context = execute_observed.context

        _received_statement, _received_parameters = self._received_statement(
            execute_observed
        )
        params = self._all_params(context)

        equivalent = self._compare_sql(execute_observed, _received_statement)

        if equivalent:
            if params is not None:
                # each expected param dict must match some received
                # param set; matched entries are removed so each is
                # used at most once
                all_params = list(params)
                all_received = list(_received_parameters)
                while all_params and all_received:
                    param = dict(all_params.pop(0))

                    for idx, received in enumerate(list(all_received)):
                        # do a positive compare only
                        for param_key in param:
                            # a key in param did not match current
                            # 'received'
                            if (
                                param_key not in received
                                or received[param_key] != param[param_key]
                            ):
                                break
                        else:
                            # all keys in param matched 'received';
                            # onto next param
                            del all_received[idx]
                            break
                    else:
                        # param did not match any entry
                        # in all_received
                        equivalent = False
                        break
                if all_params or all_received:
                    equivalent = False

        if equivalent:
            self.is_consumed = True
            self.errormessage = None
        else:
            self.errormessage = self._failure_message(params) % {
                "received_statement": _received_statement,
                "received_parameters": _received_parameters,
            }

    def _all_params(self, context):
        """Return expected params as a list of dicts, or None if no
        params were given; self.params may be a callable taking the
        execution context."""
        if self.params:
            if callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            return params
        else:
            return None

    def _failure_message(self, expected_params):
        # template with %(received_statement)s-style slots filled in by
        # process_statement(); literal percents are escaped
        return (
            "Testing for compiled statement\n%r partial params %s, "
            "received\n%%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self.statement.replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )
+
+
class RegexSQL(CompiledSQL):
    """CompiledSQL variant that matches the compiled statement against
    a regular expression instead of an exact string."""

    def __init__(self, regex, params=None, dialect="default"):
        SQLMatchRule.__init__(self)
        self.regex = re.compile(regex)
        # keep the original pattern text for failure messages
        self.orig_regex = regex
        self.params = params
        self.dialect = dialect

    def _failure_message(self, expected_params):
        return (
            "Testing for compiled statement ~%r partial params %s, "
            "received %%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self.orig_regex.replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )

    def _compare_sql(self, execute_observed, received_statement):
        return bool(self.regex.match(received_statement))
+
+
class DialectSQL(CompiledSQL):
    """CompiledSQL variant that compiles on the dialect actually in use
    and rewrites the expected string's bound-parameter syntax to the
    dialect's paramstyle before comparing."""

    def _compile_dialect(self, execute_observed):
        return execute_observed.context.dialect

    def _compare_no_space(self, real_stmt, received_stmt):
        # compare with newlines/tabs stripped
        stmt = re.sub(r"[\n\t]", "", real_stmt)
        return received_stmt == stmt

    def _received_statement(self, execute_observed):
        received_stmt, received_params = super(
            DialectSQL, self
        )._received_statement(execute_observed)

        # TODO: why do we need this part?
        for real_stmt in execute_observed.statements:
            if self._compare_no_space(real_stmt.statement, received_stmt):
                break
        else:
            raise AssertionError(
                "Can't locate compiled statement %r in list of "
                "statements actually invoked" % received_stmt
            )

        return received_stmt, execute_observed.context.compiled_parameters

    def _compare_sql(self, execute_observed, received_statement):
        stmt = re.sub(r"[\n\t]", "", self.statement)
        # convert our comparison statement to have the
        # paramstyle of the received
        paramstyle = execute_observed.context.dialect.paramstyle
        if paramstyle == "pyformat":
            stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
        else:
            # positional params
            repl = None
            if paramstyle == "qmark":
                repl = "?"
            elif paramstyle == "format":
                repl = r"%s"
            elif paramstyle == "numeric":
                repl = None
            # NOTE(review): for "numeric" (and any unlisted) paramstyle,
            # repl stays None and re.sub() with a None replacement raises
            # TypeError -- apparently never exercised; confirm before
            # relying on this path with numeric-paramstyle dialects.
            stmt = re.sub(r":([\w_]+)", repl, stmt)

        return received_statement == stmt
+
+
class CountStatements(AssertRule):
    """Rule asserting that exactly ``count`` statements were executed
    by the time the observed block completes."""

    def __init__(self, count):
        self.count = count
        self._statement_count = 0

    def process_statement(self, execute_observed):
        # tally every statement; this rule is never consumed, so it
        # observes the whole run
        self._statement_count += 1

    def no_more_statements(self):
        seen = self._statement_count
        if seen != self.count:
            assert False, "desired statement count %d does not match %d" % (
                self.count,
                seen,
            )
+
+
class AllOf(AssertRule):
    """Match when all of the given rules have matched, in any order."""

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_statement(self, execute_observed):
        for rule in list(self.rules):
            rule.errormessage = None
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                # this rule is done; AllOf completes when the set empties
                self.rules.discard(rule)
                if not self.rules:
                    self.is_consumed = True
                break
            elif not rule.errormessage:
                # rule is not done yet
                self.errormessage = None
                break
        else:
            # no rule accepted the statement; surface one rule's error
            self.errormessage = list(self.rules)[0].errormessage
+
+
class EachOf(AssertRule):
    """Match the given rules one after another, in order."""

    def __init__(self, *rules):
        self.rules = list(rules)

    def process_statement(self, execute_observed):
        # feed the statement to the head rule; pop completed rules and
        # keep going so one statement may finish several rules
        while self.rules:
            rule = self.rules[0]
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.rules.pop(0)
            elif rule.errormessage:
                self.errormessage = rule.errormessage
            if rule.consume_statement:
                break

        if not self.rules:
            self.is_consumed = True

    def no_more_statements(self):
        # delegate to the pending head rule so it can report its own
        # failure (e.g. CountStatements totals)
        if self.rules and not self.rules[0].is_consumed:
            self.rules[0].no_more_statements()
        elif self.rules:
            super(EachOf, self).no_more_statements()
+
+
class Conditional(EachOf):
    """EachOf variant applying ``rules`` when ``condition`` is true,
    else ``else_rules``."""

    def __init__(self, condition, rules, else_rules):
        if condition:
            super(Conditional, self).__init__(*rules)
        else:
            super(Conditional, self).__init__(*else_rules)
+
+
class Or(AllOf):
    """Match when any one of the given rules matches the statement."""

    def process_statement(self, execute_observed):
        for rule in self.rules:
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.is_consumed = True
                break
        else:
            # no rule matched; report an arbitrary rule's error
            self.errormessage = list(self.rules)[0].errormessage
+
+
class SQLExecuteObserved(object):
    """Record one Core-level execution: its context, clause element,
    distilled parameters, and the cursor-level statements it produced."""

    def __init__(self, context, clauseelement, multiparams, params):
        self.context = context
        self.clauseelement = clauseelement
        self.parameters = _distill_cursor_params(
            context.connection, tuple(multiparams), params
        )
        # SQLCursorExecuteObserved entries, appended by assert_engine()
        self.statements = []

    def __repr__(self):
        return str(self.statements)
+
+
class SQLCursorExecuteObserved(
    collections.namedtuple(
        "SQLCursorExecuteObserved",
        ["statement", "parameters", "context", "executemany"],
    )
):
    """Record a single DBAPI cursor execution."""

    pass
+
+
class SQLAsserter(object):
    """Accumulate observed executions while recording, then apply
    assertion rules to them after the fact."""

    def __init__(self):
        self.accumulated = []

    def _close(self):
        # freeze the accumulated list; no further recording possible
        self._final = self.accumulated
        del self.accumulated

    def assert_(self, *rules):
        """Assert the recorded statements satisfy ``rules`` in sequence
        (the rules are wrapped in EachOf)."""
        rule = EachOf(*rules)

        observed = list(self._final)
        while observed:
            statement = observed.pop(0)
            rule.process_statement(statement)
            if rule.is_consumed:
                break
            elif rule.errormessage:
                assert False, rule.errormessage
        if observed:
            assert False, "Additional SQL statements remain:\n%s" % observed
        elif not rule.is_consumed:
            # let pending rules (e.g. CountStatements) report
            rule.no_more_statements()
+
+
@contextlib.contextmanager
def assert_engine(engine):
    """Context manager yielding a SQLAsserter that records all SQL
    executed on ``engine`` for the duration of the block.

    Event hooks are removed and the asserter finalized on exit.
    """
    asserter = SQLAsserter()

    orig = []

    @event.listens_for(engine, "before_execute")
    def connection_execute(
        conn, clauseelement, multiparams, params, execution_options
    ):
        # grab the original statement + params before any cursor
        # execution
        orig[:] = clauseelement, multiparams, params

    @event.listens_for(engine, "after_cursor_execute")
    def cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        if not context:
            return
        # then grab real cursor statements and associate them all
        # around a single context
        if (
            asserter.accumulated
            and asserter.accumulated[-1].context is context
        ):
            obs = asserter.accumulated[-1]
        else:
            obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
            asserter.accumulated.append(obs)
        obs.statements.append(
            SQLCursorExecuteObserved(
                statement, parameters, context, executemany
            )
        )

    try:
        yield asserter
    finally:
        event.remove(engine, "after_cursor_execute", cursor_execute)
        event.remove(engine, "before_execute", connection_execute)
        asserter._close()
diff --git a/lib/sqlalchemy/testing/asyncio.py b/lib/sqlalchemy/testing/asyncio.py
new file mode 100644
index 0000000..2189060
--- /dev/null
+++ b/lib/sqlalchemy/testing/asyncio.py
@@ -0,0 +1,128 @@
+# testing/asyncio.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+# functions and wrappers to run tests, fixtures, provisioning and
+# setup/teardown in an asyncio event loop, conditionally based on the
+# current DB driver being used for a test.
+
+# note that SQLAlchemy's asyncio integration also supports a method
+# of running individual asyncio functions inside of separate event loops
+# using "async_fallback" mode; however running whole functions in the event
+# loop is a more accurate test for how SQLAlchemy's asyncio features
+# would run in the real world.
+
+
+from functools import wraps
+import inspect
+
+from . import config
+from ..util.concurrency import _util_async_run
+from ..util.concurrency import _util_async_run_coroutine_function
+
+# may be set to False if the
+# --disable-asyncio flag is passed to the test runner.
+ENABLE_ASYNCIO = True
+
+
def _run_coroutine_function(fn, *args, **kwargs):
    """Run coroutine function ``fn`` to completion in an event loop and
    return its result."""
    return _util_async_run_coroutine_function(fn, *args, **kwargs)
+
+
def _assume_async(fn, *args, **kwargs):
    """Run a function in an asyncio loop unconditionally.

    This function is used for provisioning features like
    testing a database connection for server info.

    Note that for blocking IO database drivers, this means they block the
    event loop.

    """

    if not ENABLE_ASYNCIO:
        # asyncio disabled via --disable-asyncio; run inline
        return fn(*args, **kwargs)

    return _util_async_run(fn, *args, **kwargs)
+
+
def _maybe_async_provisioning(fn, *args, **kwargs):
    """Run ``fn`` inside an asyncio loop when any configured backend
    uses an async driver; otherwise call it directly.

    Used for provisioning steps that take place before a specific
    driver is selected: wrapping them keeps async drivers from failing,
    while blocking-IO drivers simply block the event loop.
    """
    if ENABLE_ASYNCIO and config.any_async:
        return _util_async_run(fn, *args, **kwargs)
    return fn(*args, **kwargs)
+
+
def _maybe_async(fn, *args, **kwargs):
    """Run ``fn`` inside an asyncio loop when the currently selected
    driver is async; otherwise call it directly.

    Used for test setup/teardown and test bodies, where the active DB
    driver is known.
    """
    if not ENABLE_ASYNCIO:
        return fn(*args, **kwargs)

    if config._current.is_async:
        return _util_async_run(fn, *args, **kwargs)
    return fn(*args, **kwargs)
+
+
def _maybe_async_wrapper(fn):
    """Apply the _maybe_async function to an existing function and return
    as a wrapped callable, supporting generator functions as well.

    This is currently used for pytest fixtures that support generator use.

    """

    if inspect.isgeneratorfunction(fn):
        # sentinel marking generator exhaustion
        _stop = object()

        def call_next(gen):
            try:
                return next(gen)
            # can't raise StopIteration in an awaitable.
            except StopIteration:
                return _stop

        @wraps(fn)
        def wrap_fixture(*args, **kwargs):
            # drive the wrapped generator one step at a time, each step
            # possibly inside the event loop
            gen = fn(*args, **kwargs)
            while True:
                value = _maybe_async(call_next, gen)
                if value is _stop:
                    break
                yield value

    else:

        @wraps(fn)
        def wrap_fixture(*args, **kwargs):
            return _maybe_async(fn, *args, **kwargs)

    return wrap_fixture
diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py
new file mode 100644
index 0000000..fc13a16
--- /dev/null
+++ b/lib/sqlalchemy/testing/config.py
@@ -0,0 +1,209 @@
+# testing/config.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import collections
+
+from .. import util
+
# Module-level, per-test-run state.  Several of these (db, db_url,
# db_opts, test_schema, test_schema_2, _current) are assigned by
# Config.set_as_current(); any_async is set in Config.register().
# requirements / file_config are presumably installed by the test
# plugin during startup -- confirm against plugin_base.
requirements = None
db = None
db_url = None
db_opts = None
file_config = None
test_schema = None
test_schema_2 = None
# True once any registered backend uses a real async driver
any_async = False
# the currently active Config
_current = None
ident = "main"

_fixture_functions = None  # installed by plugin_base
+
+
def combinations(*comb, **kw):
    r"""Deliver multiple versions of a test based on positional combinations.

    This is a facade over pytest.mark.parametrize.


    :param \*comb: argument combinations. These are tuples that will be passed
     positionally to the decorated function.

    :param argnames: optional list of argument names. These are the names
     of the arguments in the test function that correspond to the entries
     in each argument tuple. pytest.mark.parametrize requires this, however
     the combinations function will derive it automatically if not present
     by using ``inspect.getfullargspec(fn).args[1:]``. Note this assumes the
     first argument is "self" which is discarded.

    :param id\_: optional id template. This is a string template that
     describes how the "id" for each parameter set should be defined, if any.
     The number of characters in the template should match the number of
     entries in each argument tuple. Each character describes how the
     corresponding entry in the argument tuple should be handled, as far as
     whether or not it is included in the arguments passed to the function, as
     well as if it is included in the tokens used to create the id of the
     parameter set.

     If omitted, the argument combinations are passed to parametrize as is. If
     passed, each argument combination is turned into a pytest.param() object,
     mapping the elements of the argument tuple to produce an id based on a
     character value in the same position within the string template using the
     following scheme::

        i - the given argument is a string that is part of the id only, don't
            pass it as an argument

        n - the given argument should be passed and it should be added to the
            id by calling the .__name__ attribute

        r - the given argument should be passed and it should be added to the
            id by calling repr()

        s - the given argument should be passed and it should be added to the
            id by calling str()

        a - (argument) the given argument should be passed and it should not
            be used to generated the id

     e.g.::

        @testing.combinations(
            (operator.eq, "eq"),
            (operator.ne, "ne"),
            (operator.gt, "gt"),
            (operator.lt, "lt"),
            id_="na"
        )
        def test_operator(self, opfunc, name):
            pass

    The above combination will call ``.__name__`` on the first member of
    each tuple and use that as the "id" to pytest.param().


    """
    # actual implementation installed by the test plugin (plugin_base)
    return _fixture_functions.combinations(*comb, **kw)
+
+
def combinations_list(arg_iterable, **kw):
    """As :func:`.combinations`, but takes a single iterable of
    argument tuples."""
    return combinations(*arg_iterable, **kw)
+
+
def fixture(*arg, **kw):
    """Facade over the test plugin's fixture decorator."""
    return _fixture_functions.fixture(*arg, **kw)
+
+
def get_current_test_name():
    """Return the name of the currently running test, via the plugin."""
    return _fixture_functions.get_current_test_name()
+
+
def mark_base_test_class():
    """Return the plugin's marker applied to base test classes."""
    return _fixture_functions.mark_base_test_class()
+
+
class Config(object):
    """One database configuration (engine + options) under test.

    Class-level state tracks all registered configs and a stack of
    pushed "current" configs; module-level globals mirror the current
    one.
    """

    def __init__(self, db, db_opts, options, file_config):
        self._set_name(db)
        self.db = db
        self.db_opts = db_opts
        self.options = options
        self.file_config = file_config
        self.test_schema = "test_schema"
        self.test_schema_2 = "test_schema_2"

        # an async driver runs "real" async unless async_fallback is
        # set on the URL
        self.is_async = db.dialect.is_async and not util.asbool(
            db.url.query.get("async_fallback", False)
        )

    # stack of previously-current configs, used by push()/pop()
    _stack = collections.deque()
    # all configs registered via register()
    _configs = set()

    def _set_name(self, db):
        # display name, including the server version when available
        if db.dialect.server_version_info:
            svi = ".".join(str(tok) for tok in db.dialect.server_version_info)
            self.name = "%s+%s_[%s]" % (db.name, db.driver, svi)
        else:
            self.name = "%s+%s" % (db.name, db.driver)

    @classmethod
    def register(cls, db, db_opts, options, file_config):
        """add a config as one of the global configs.

        If there are no configs set up yet, this config also
        gets set as the "_current".
        """
        global any_async

        cfg = Config(db, db_opts, options, file_config)

        # if any backends include an async driver, then ensure
        # all setup/teardown and tests are wrapped in the maybe_async()
        # decorator that will set up a greenlet context for async drivers.
        any_async = any_async or cfg.is_async

        cls._configs.add(cfg)
        return cfg

    @classmethod
    def set_as_current(cls, config, namespace):
        """Install ``config`` as current, mirroring it into the
        module-level globals and the given namespace."""
        global db, _current, db_url, test_schema, test_schema_2, db_opts
        _current = config
        db_url = config.db.url
        db_opts = config.db_opts
        test_schema = config.test_schema
        test_schema_2 = config.test_schema_2
        namespace.db = db = config.db

    @classmethod
    def push_engine(cls, db, namespace):
        """Push a new current config built from ``db`` plus the current
        config's options."""
        assert _current, "Can't push without a default Config set up"
        cls.push(
            Config(
                db, _current.db_opts, _current.options, _current.file_config
            ),
            namespace,
        )

    @classmethod
    def push(cls, config, namespace):
        cls._stack.append(_current)
        cls.set_as_current(config, namespace)

    @classmethod
    def pop(cls, namespace):
        if cls._stack:
            # a failed test w/ -x option can call reset() ahead of time
            _current = cls._stack[-1]
            del cls._stack[-1]
            cls.set_as_current(_current, namespace)

    @classmethod
    def reset(cls, namespace):
        """Restore the bottom-of-stack config and clear the stack."""
        if cls._stack:
            cls.set_as_current(cls._stack[0], namespace)
            cls._stack.clear()

    @classmethod
    def all_configs(cls):
        return cls._configs

    @classmethod
    def all_dbs(cls):
        for cfg in cls.all_configs():
            yield cfg.db

    def skip_test(self, msg):
        skip_test(msg)
+
+
def skip_test(msg):
    """Skip the current test by raising the plugin's skip exception."""
    raise _fixture_functions.skip_test_exception(msg)
+
+
def async_test(fn):
    """Wrap ``fn`` as an asyncio test via the installed plugin."""
    return _fixture_functions.async_test(fn)
diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py
new file mode 100644
index 0000000..b8be6b9
--- /dev/null
+++ b/lib/sqlalchemy/testing/engines.py
@@ -0,0 +1,465 @@
+# testing/engines.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+import collections
+import re
+import warnings
+import weakref
+
+from . import config
+from .util import decorator
+from .util import gc_collect
+from .. import event
+from .. import pool
+from ..util import await_only
+
+
class ConnectionKiller(object):
    """Track pools, engines and DBAPI connections created by tests so
    they can be rolled back, checked in and disposed en masse."""

    def __init__(self):
        # checked-out connection proxies, weakly held
        self.proxy_refs = weakref.WeakKeyDictionary()
        # engines grouped by cleanup scope
        self.testing_engines = collections.defaultdict(set)
        # raw DBAPI connections seen at checkout
        self.dbapi_connections = set()

    def add_pool(self, pool):
        """Listen on a pool so its connections are tracked."""
        event.listen(pool, "checkout", self._add_conn)
        event.listen(pool, "checkin", self._remove_conn)
        event.listen(pool, "close", self._remove_conn)
        event.listen(pool, "close_detached", self._remove_conn)
        # note we are keeping "invalidated" here, as those are still
        # opened connections we would like to roll back

    def _add_conn(self, dbapi_con, con_record, con_proxy):
        self.dbapi_connections.add(dbapi_con)
        self.proxy_refs[con_proxy] = True

    def _remove_conn(self, dbapi_conn, *arg):
        self.dbapi_connections.discard(dbapi_conn)

    def add_engine(self, engine, scope):
        """Track ``engine`` for cleanup under the given scope."""
        self.add_pool(engine.pool)

        assert scope in ("class", "global", "function", "fixture")
        self.testing_engines[scope].add(engine)

    def _safe(self, fn):
        # best-effort call; cleanup must not fail the test run
        try:
            fn()
        except Exception as e:
            warnings.warn(
                "testing_reaper couldn't rollback/close connection: %s" % e
            )

    def rollback_all(self):
        """Roll back every valid tracked connection."""
        for rec in list(self.proxy_refs):
            if rec is not None and rec.is_valid:
                self._safe(rec.rollback)

    def checkin_all(self):
        # run pool.checkin() for all ConnectionFairy instances we have
        # tracked.

        for rec in list(self.proxy_refs):
            if rec is not None and rec.is_valid:
                self.dbapi_connections.discard(rec.dbapi_connection)
                self._safe(rec._checkin)

        # for fairy refs that were GCed and could not close the connection,
        # such as asyncio, roll back those remaining connections
        for con in self.dbapi_connections:
            self._safe(con.rollback)
        self.dbapi_connections.clear()

    def close_all(self):
        self.checkin_all()

    def prepare_for_drop_tables(self, connection):
        # don't do aggressive checks for third party test suites
        if not config.bootstrapped_as_sqlalchemy:
            return

        from . import provision

        provision.prepare_for_drop_tables(connection.engine.url, connection)

    def _drop_testing_engines(self, scope):
        """Check in and dispose every engine registered in ``scope``."""
        eng = self.testing_engines[scope]
        for rec in list(eng):
            for proxy_ref in list(self.proxy_refs):
                if proxy_ref is not None and proxy_ref.is_valid:
                    if (
                        proxy_ref._pool is not None
                        and proxy_ref._pool is rec.pool
                    ):
                        self._safe(proxy_ref._checkin)
            # async engines (which have .sync_engine) must be disposed
            # through the event loop
            if hasattr(rec, "sync_engine"):
                await_only(rec.dispose())
            else:
                rec.dispose()
        eng.clear()

    def after_test(self):
        self._drop_testing_engines("function")

    def after_test_outside_fixtures(self, test):
        # don't do aggressive checks for third party test suites
        if not config.bootstrapped_as_sqlalchemy:
            return

        if test.__class__.__leave_connections_for_teardown__:
            return

        self.checkin_all()

        # on PostgreSQL, this will test for any "idle in transaction"
        # connections.   useful to identify tests with unusual patterns
        # that can't be cleaned up correctly.
        from . import provision

        with config.db.connect() as conn:
            provision.prepare_for_drop_tables(conn.engine.url, conn)

    def stop_test_class_inside_fixtures(self):
        self.checkin_all()
        self._drop_testing_engines("function")
        self._drop_testing_engines("class")

    def stop_test_class_outside_fixtures(self):
        # ensure no refs to checked out connections at all.

        if pool.base._strong_ref_connection_records:
            gc_collect()

        if pool.base._strong_ref_connection_records:
            ln = len(pool.base._strong_ref_connection_records)
            pool.base._strong_ref_connection_records.clear()
            assert (
                False
            ), "%d connection recs not cleared after test suite" % (ln)

    def final_cleanup(self):
        self.checkin_all()
        for scope in self.testing_engines:
            self._drop_testing_engines(scope)

    def assert_all_closed(self):
        for rec in self.proxy_refs:
            if rec.is_valid:
                assert False
+
+
+testing_reaper = ConnectionKiller()
+
+
@decorator
def assert_conns_closed(fn, *args, **kw):
    """Decorator asserting all tracked connections are closed after fn."""
    try:
        fn(*args, **kw)
    finally:
        testing_reaper.assert_all_closed()
+
+
@decorator
def rollback_open_connections(fn, *args, **kw):
    """Decorator that rolls back all open connections after fn execution."""

    try:
        fn(*args, **kw)
    finally:
        testing_reaper.rollback_all()
+
+
@decorator
def close_first(fn, *args, **kw):
    """Decorator that closes all connections before fn execution."""

    testing_reaper.checkin_all()
    fn(*args, **kw)
+
+
@decorator
def close_open_connections(fn, *args, **kw):
    """Decorator that closes all connections after fn execution."""
    try:
        fn(*args, **kw)
    finally:
        testing_reaper.checkin_all()
+
+
def all_dialects(exclude=None):
    """Yield an instance of each dialect in sqlalchemy.dialects.__all__,
    minus any names in ``exclude``."""
    import sqlalchemy.dialects as d

    for name in d.__all__:
        # TEMPORARY
        if exclude and name in exclude:
            continue
        mod = getattr(d, name, None)
        if not mod:
            # submodule not yet imported; import it explicitly
            mod = getattr(
                __import__("sqlalchemy.dialects.%s" % name).dialects, name
            )
        yield mod.dialect()
+
+
class ReconnectFixture(object):
    """Wrap a DBAPI module so its connections can be mass-closed to
    simulate a database restart / dropped connections.

    All attributes other than connect() pass through to the real DBAPI.
    """

    def __init__(self, dbapi):
        self.dbapi = dbapi
        self.connections = []
        self.is_stopped = False

    def __getattr__(self, key):
        # everything not overridden delegates to the wrapped DBAPI
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):

        conn = self.dbapi.connect(*args, **kwargs)
        if not self.is_stopped:
            self.connections.append(conn)
            return conn

        # "stopped" mode: close the fresh connection, then prove that
        # using it raises as a real dropped connection would
        self._safe(conn.close)
        cursor = conn.cursor()  # should fail on Oracle etc.
        # should fail for everything that didn't fail
        # above, connection is closed
        cursor.execute("select 1")
        assert False, "simulated connect failure didn't work"

    def _safe(self, fn):
        # best-effort close; emit a warning rather than failing the test
        try:
            fn()
        except Exception as e:
            warnings.warn("ReconnectFixture couldn't close connection: %s" % e)

    def shutdown(self, stop=False):
        # TODO: this doesn't cover all cases
        # as nicely as we'd like, namely MySQLdb.
        # would need to implement R. Brewer's
        # proxy server idea to get better
        # coverage.
        self.is_stopped = stop
        for conn in list(self.connections):
            self._safe(conn.close)
        self.connections = []

    def restart(self):
        self.is_stopped = False
+
+
def reconnecting_engine(url=None, options=None):
    """Return a testing engine whose DBAPI is wrapped in
    ReconnectFixture, adding ``test_shutdown()`` / ``test_restart()``
    hooks to simulate disconnects."""
    url = url or config.db.url
    dbapi = config.db.dialect.dbapi
    if not options:
        options = {}
    options["module"] = ReconnectFixture(dbapi)
    engine = testing_engine(url, options)
    _dispose = engine.dispose

    def dispose():
        # close the fixture's connections and re-enable it before
        # actually disposing the pool
        engine.dialect.dbapi.shutdown()
        engine.dialect.dbapi.is_stopped = False
        _dispose()

    engine.test_shutdown = engine.dialect.dbapi.shutdown
    engine.test_restart = engine.dialect.dbapi.restart
    engine.dispose = dispose
    return engine
+
+
def testing_engine(
    url=None,
    options=None,
    future=None,
    asyncio=False,
    transfer_staticpool=False,
    _sqlite_savepoint=False,
):
    """Produce an engine configured by --options with optional overrides."""

    if asyncio:
        assert not _sqlite_savepoint
        from sqlalchemy.ext.asyncio import (
            create_async_engine as create_engine,
        )
    elif future or (
        config.db and config.db._is_future and future is not False
    ):
        from sqlalchemy.future import create_engine
    else:
        from sqlalchemy import create_engine
    from sqlalchemy.engine.url import make_url

    # pull testing-only options out before they reach create_engine()
    if not options:
        use_reaper = True
        scope = "function"
        sqlite_savepoint = False
    else:
        use_reaper = options.pop("use_reaper", True)
        scope = options.pop("scope", "function")
        sqlite_savepoint = options.pop("sqlite_savepoint", False)

    url = url or config.db.url

    url = make_url(url)
    if options is None:
        if config.db is None or url.drivername == config.db.url.drivername:
            options = config.db_opts
        else:
            options = {}
    elif config.db is not None and url.drivername == config.db.url.drivername:
        default_opt = config.db_opts.copy()
        default_opt.update(options)
        # NOTE(review): default_opt is built but never used here;
        # ``options`` is passed to create_engine() unmerged.  Possibly
        # dead code -- confirm intent before changing.

    engine = create_engine(url, **options)

    if sqlite_savepoint and engine.name == "sqlite":
        # apply SQLite savepoint workaround
        @event.listens_for(engine, "connect")
        def do_connect(dbapi_connection, connection_record):
            dbapi_connection.isolation_level = None

        @event.listens_for(engine, "begin")
        def do_begin(conn):
            conn.exec_driver_sql("BEGIN")

    if transfer_staticpool:
        from sqlalchemy.pool import StaticPool

        if config.db is not None and isinstance(config.db.pool, StaticPool):
            use_reaper = False
            engine.pool._transfer_from(config.db.pool)

    if scope == "global":
        if asyncio:
            engine.sync_engine._has_events = True
        else:
            engine._has_events = (
                True  # enable event blocks, helps with profiling
            )

    if isinstance(engine.pool, pool.QueuePool):
        # fail fast in tests rather than waiting on the pool
        engine.pool._timeout = 0
        engine.pool._max_overflow = 0
    if use_reaper:
        testing_reaper.add_engine(engine, scope)

    return engine
+
+
def mock_engine(dialect_name=None):
    """Provides a mocking engine based on the current testing.db.

    This is normally used to test DDL generation flow as emitted
    by an Engine.

    It should not be used in other cases, as assert_compile() and
    assert_sql_execution() are much better choices with fewer
    moving parts.

    """

    from sqlalchemy import create_mock_engine

    if not dialect_name:
        dialect_name = config.db.name

    # statements emitted through the engine accumulate here
    buffer = []

    def executor(sql, *a, **kw):
        buffer.append(sql)

    def assert_sql(stmts):
        # compare whitespace-normalized statement strings
        recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer]
        assert recv == stmts, recv

    def print_sql():
        d = engine.dialect
        return "\n".join(str(s.compile(dialect=d)) for s in engine.mock)

    engine = create_mock_engine(dialect_name + "://", executor)
    assert not hasattr(engine, "mock")
    engine.mock = buffer
    engine.assert_sql = assert_sql
    engine.print_sql = print_sql
    return engine
+
+
class DBAPIProxyCursor(object):
    """Proxy a DBAPI cursor.

    Tests can subclass this to intercept DBAPI-level cursor
    operations; unknown attributes pass through to the real cursor.
    """

    def __init__(self, engine, conn, *args, **kwargs):
        self.engine = engine
        self.connection = conn
        self.cursor = conn.cursor(*args, **kwargs)

    def execute(self, stmt, parameters=None, **kw):
        # omit the parameters argument entirely when none were given
        if not parameters:
            return self.cursor.execute(stmt, **kw)
        return self.cursor.execute(stmt, parameters, **kw)

    def executemany(self, stmt, params, **kw):
        return self.cursor.executemany(stmt, params, **kw)

    def __iter__(self):
        return iter(self.cursor)

    def __getattr__(self, key):
        return getattr(self.cursor, key)
+
+
class DBAPIProxyConnection(object):
    """Proxy a DBAPI connection.

    Tests can provide subclasses of this to intercept
    DBAPI-level connection operations.

    """

    def __init__(self, engine, cursor_cls):
        # obtain a raw DBAPI connection via the engine pool's creator
        self.conn = engine.pool._creator()
        self.engine = engine
        self.cursor_cls = cursor_cls

    def cursor(self, *args, **kwargs):
        # wrap cursors in the configured proxy class
        return self.cursor_cls(self.engine, self.conn, *args, **kwargs)

    def close(self):
        self.conn.close()

    def __getattr__(self, key):
        # delegate everything else to the real connection
        return getattr(self.conn, key)
+
+
def proxying_engine(
    conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor
):
    """Produce an engine that provides proxy hooks for
    common methods.

    """

    def mock_conn():
        # creator used by the pool in place of the real DBAPI connect
        return conn_cls(config.db, cursor_cls)

    def _wrap_do_on_connect(do_on_connect):
        # unwrap the proxy before handing the raw connection to
        # on-connect handlers
        def go(dbapi_conn):
            return do_on_connect(dbapi_conn.conn)

        return go

    return testing_engine(
        options={
            "creator": mock_conn,
            "_wrap_do_on_connect": _wrap_do_on_connect,
        }
    )
diff --git a/lib/sqlalchemy/testing/entities.py b/lib/sqlalchemy/testing/entities.py
new file mode 100644
index 0000000..8ea65d6
--- /dev/null
+++ b/lib/sqlalchemy/testing/entities.py
@@ -0,0 +1,111 @@
+# testing/entities.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import sqlalchemy as sa
+from .. import exc as sa_exc
+from ..util import compat
+
_repr_stack = set()


class BasicEntity(object):
    """Simple test entity: assigns keyword arguments as attributes and
    renders a recursion-safe ``__repr__`` of its public attributes."""

    def __init__(self, **kw):
        for key, value in kw.items():
            setattr(self, key, value)

    def __repr__(self):
        # guard against infinite recursion when entities reference each
        # other; fall back to the default repr in that case
        if id(self) in _repr_stack:
            return object.__repr__(self)
        _repr_stack.add(id(self))
        try:
            attrs = ", ".join(
                "%s=%r" % (key, getattr(self, key))
                for key in sorted(self.__dict__)
                if not key.startswith("_")
            )
            return "%s(%s)" % (self.__class__.__name__, attrs)
        finally:
            _repr_stack.remove(id(self))
+
+
_recursion_stack = set()


class ComparableMixin(object):
    # NOTE(review): py2-style inversion; assumes __eq__ returns a bool
    # rather than NotImplemented
    def __ne__(self, other):
        return not self.__eq__(other)

    def __eq__(self, other):
        """Deep, sparse compare.

        Deeply compare two entities, following the non-None attributes of the
        non-persisted object, if possible.

        """
        if other is self:
            return True
        elif not self.__class__ == other.__class__:
            return False

        # guard against cycles between mutually-referencing entities;
        # treat a re-entrant comparison as equal
        if id(self) in _recursion_stack:
            return True
        _recursion_stack.add(id(self))

        try:
            # pick the entity that's not SA persisted as the source
            try:
                self_key = sa.orm.attributes.instance_state(self).key
            except sa.orm.exc.NO_STATE:
                self_key = None

            if other is None:
                a = self
                b = other
            elif self_key is not None:
                a = other
                b = self
            else:
                a = self
                b = other

            # compare only attributes present on the "source" side
            for attr in list(a.__dict__):
                if attr.startswith("_"):
                    continue
                value = getattr(a, attr)

                try:
                    # handle lazy loader errors
                    battr = getattr(b, attr)
                except (AttributeError, sa_exc.UnboundExecutionError):
                    return False

                if hasattr(value, "__iter__") and not isinstance(
                    value, compat.string_types
                ):
                    # sequences compare ordered; sets / mappings compare
                    # as unordered sets of members
                    if hasattr(value, "__getitem__") and not hasattr(
                        value, "keys"
                    ):
                        if list(value) != list(battr):
                            return False
                    else:
                        if set(value) != set(battr):
                            return False
                else:
                    # None attributes on the source side are not compared
                    if value is not None and value != battr:
                        return False
            return True
        finally:
            _recursion_stack.remove(id(self))
+
+
class ComparableEntity(ComparableMixin, BasicEntity):
    """Entity combining keyword construction with deep comparison; hashes
    on the class so instances hash identically and compare via __eq__."""

    def __hash__(self):
        return hash(self.__class__)
diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py
new file mode 100644
index 0000000..521a4aa
--- /dev/null
+++ b/lib/sqlalchemy/testing/exclusions.py
@@ -0,0 +1,465 @@
+# testing/exclusions.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+
+import contextlib
+import operator
+import re
+import sys
+
+from . import config
+from .. import util
+from ..util import decorator
+from ..util.compat import inspect_getfullargspec
+
+
def skip_if(predicate, reason=None):
    """Return a rule that skips a test when ``predicate`` holds."""
    rule = compound()
    rule.skips.add(_as_predicate(predicate, reason))
    return rule
+
+
def fails_if(predicate, reason=None):
    """Return a rule marking a test as expected-to-fail when
    ``predicate`` holds."""
    rule = compound()
    rule.fails.add(_as_predicate(predicate, reason))
    return rule
+
+
class compound(object):
    """A composable collection of skip / expected-failure predicates plus
    tags, usable as a test decorator or evaluated against a config."""

    def __init__(self):
        # predicates under which the test is expected to fail
        self.fails = set()
        # predicates under which the test is skipped outright
        self.skips = set()
        # free-form tag names used for include/exclude filtering
        self.tags = set()

    def __add__(self, other):
        return self.add(other)

    def as_skips(self):
        # return a copy where expected-failures are demoted to skips
        rule = compound()
        rule.skips.update(self.skips)
        rule.skips.update(self.fails)
        rule.tags.update(self.tags)
        return rule

    def add(self, *others):
        # non-mutating union of this rule with any number of others
        copy = compound()
        copy.fails.update(self.fails)
        copy.skips.update(self.skips)
        copy.tags.update(self.tags)
        for other in others:
            copy.fails.update(other.fails)
            copy.skips.update(other.skips)
            copy.tags.update(other.tags)
        return copy

    def not_(self):
        # invert every predicate; tags are carried through unchanged
        copy = compound()
        copy.fails.update(NotPredicate(fail) for fail in self.fails)
        copy.skips.update(NotPredicate(skip) for skip in self.skips)
        copy.tags.update(self.tags)
        return copy

    @property
    def enabled(self):
        return self.enabled_for_config(config._current)

    def enabled_for_config(self, config):
        # enabled only if no skip/fail predicate matches this config
        for predicate in self.skips.union(self.fails):
            if predicate(config):
                return False
        else:
            return True

    def matching_config_reasons(self, config):
        # human-readable reasons for every predicate matching the config
        return [
            predicate._as_string(config)
            for predicate in self.skips.union(self.fails)
            if predicate(config)
        ]

    def include_test(self, include_tags, exclude_tags):
        # excluded tags always win; include list, when present, requires
        # at least one match
        return bool(
            not self.tags.intersection(exclude_tags)
            and (not include_tags or self.tags.intersection(include_tags))
        )

    def _extend(self, other):
        # in-place merge, used when stacking decorators on one function
        self.skips.update(other.skips)
        self.fails.update(other.fails)
        self.tags.update(other.tags)

    def __call__(self, fn):
        # if the function was already wrapped by an exclusion rule, just
        # merge this rule into the existing one
        if hasattr(fn, "_sa_exclusion_extend"):
            fn._sa_exclusion_extend._extend(self)
            return fn

        @decorator
        def decorate(fn, *args, **kw):
            return self._do(config._current, fn, *args, **kw)

        decorated = decorate(fn)
        decorated._sa_exclusion_extend = self
        return decorated

    @contextlib.contextmanager
    def fail_if(self):
        # context-manager form: all predicates (skips included) are
        # treated as expected-failure conditions for the block
        all_fails = compound()
        all_fails.fails.update(self.skips.union(self.fails))

        try:
            yield
        except Exception as ex:
            all_fails._expect_failure(config._current, ex)
        else:
            all_fails._expect_success(config._current)

    def _do(self, cfg, fn, *args, **kw):
        # evaluate skips first; a matching skip aborts the test here
        for skip in self.skips:
            if skip(cfg):
                msg = "'%s' : %s" % (
                    config.get_current_test_name(),
                    skip._as_string(cfg),
                )
                config.skip_test(msg)

        try:
            return_value = fn(*args, **kw)
        except Exception as ex:
            self._expect_failure(cfg, ex, name=fn.__name__)
        else:
            self._expect_success(cfg, name=fn.__name__)
            return return_value

    def _expect_failure(self, config, ex, name="block"):
        # if any fail predicate matches, the exception was expected and
        # is reported; otherwise it is re-raised with its traceback
        for fail in self.fails:
            if fail(config):
                if util.py2k:
                    str_ex = unicode(ex).encode(  # noqa: F821
                        "utf-8", errors="ignore"
                    )
                else:
                    str_ex = str(ex)
                print(
                    (
                        "%s failed as expected (%s): %s "
                        % (name, fail._as_string(config), str_ex)
                    )
                )
                break
        else:
            util.raise_(ex, with_traceback=sys.exc_info()[2])

    def _expect_success(self, config, name="block"):
        # success is an error when a fail predicate matched this config
        if not self.fails:
            return

        for fail in self.fails:
            if fail(config):
                raise AssertionError(
                    "Unexpected success for '%s' (%s)"
                    % (
                        name,
                        " and ".join(
                            fail._as_string(config) for fail in self.fails
                        ),
                    )
                )
+
+
def requires_tag(tagname):
    """Shorthand for :func:`.tags` with a single tag name."""
    return tags([tagname])
+
+
def tags(tagnames):
    """Return a compound carrying the given tag names (no predicates)."""
    comp = compound()
    comp.tags.update(tagnames)
    return comp
+
+
def only_if(predicate, reason=None):
    """Skip the test unless ``predicate`` holds."""
    return skip_if(NotPredicate(_as_predicate(predicate)), reason)
+
+
def succeeds_if(predicate, reason=None):
    """Expect the test to fail unless ``predicate`` holds."""
    return fails_if(NotPredicate(_as_predicate(predicate)), reason)
+
+
class Predicate(object):
    """Base class for exclusion predicates.

    :meth:`.as_predicate` adapts the many shorthand forms accepted by the
    exclusion helpers (compound rules, lists/sets, ``(db, op, spec)``
    tuples, ``"db>=1.2"`` strings, plain callables) into predicate
    objects.
    """

    @classmethod
    def as_predicate(cls, predicate, description=None):
        if isinstance(predicate, compound):
            return cls.as_predicate(predicate.enabled_for_config, description)
        elif isinstance(predicate, Predicate):
            if description and predicate.description is None:
                predicate.description = description
            return predicate
        elif isinstance(predicate, (list, set)):
            return OrPredicate(
                [cls.as_predicate(pred) for pred in predicate], description
            )
        elif isinstance(predicate, tuple):
            return SpecPredicate(*predicate)
        elif isinstance(predicate, util.string_types):
            # "dialect+driver <op> <version>", op/version optional
            tokens = re.match(
                r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
            )
            if not tokens:
                raise ValueError(
                    "Couldn't locate DB name in predicate: %r" % predicate
                )
            db = tokens.group(1)
            op = tokens.group(2)
            spec = (
                tuple(int(d) for d in tokens.group(3).split("."))
                if tokens.group(3)
                else None
            )

            return SpecPredicate(db, op, spec, description=description)
        elif callable(predicate):
            return LambdaPredicate(predicate, description)
        else:
            assert False, "unknown predicate type: %s" % predicate

    def _format_description(self, config, negate=False):
        """Render ``self.description`` with driver/database context.

        ``negate`` flips the does/doesn't-support wording so inverted
        predicates describe themselves correctly.
        """
        bool_ = self(config)
        if negate:
            # bugfix: negate the evaluated result.  This previously
            # assigned ``not negate`` -- always False on this branch --
            # so negated descriptions rendered the support wording wrong.
            bool_ = not bool_
        return self.description % {
            "driver": config.db.url.get_driver_name()
            if config
            else "<no driver>",
            "database": config.db.url.get_backend_name()
            if config
            else "<no database>",
            "doesnt_support": "doesn't support" if bool_ else "does support",
            "does_support": "does support" if bool_ else "doesn't support",
        }

    def _as_string(self, config=None, negate=False):
        raise NotImplementedError()
+
+
class BooleanPredicate(Predicate):
    """Predicate with a fixed truth value, independent of configuration."""

    def __init__(self, value, description=None):
        self.value = value
        # default description simply names the constant value
        self.description = description or "boolean %s" % value

    def __call__(self, config):
        return self.value

    def _as_string(self, config, negate=False):
        return self._format_description(config, negate=negate)
+
+
class SpecPredicate(Predicate):
    """Predicate matching a database name (optionally ``dialect+driver``)
    plus an optional server-version comparison, e.g.
    ``SpecPredicate("postgresql", ">=", (9, 6))``."""

    # comparison operators accepted for the ``op`` argument
    _ops = {
        "<": operator.lt,
        ">": operator.gt,
        "==": operator.eq,
        "!=": operator.ne,
        "<=": operator.le,
        ">=": operator.ge,
        "in": operator.contains,
        "between": lambda val, pair: val >= pair[0] and val <= pair[1],
    }

    def __init__(self, db, op=None, spec=None, description=None):
        self.db = db
        self.op = op
        self.spec = spec
        self.description = description

    def __call__(self, config):
        if config is None:
            return False

        engine = config.db

        # "dialect+driver" targets a specific DBAPI; a bare dialect name
        # matches any driver
        if "+" in self.db:
            dialect, driver = self.db.split("+")
        else:
            dialect, driver = self.db, None

        if dialect and engine.name != dialect:
            return False
        if driver is not None and engine.driver != driver:
            return False

        if self.op is None:
            return True

        assert driver is None, "DBAPI version specs not supported yet"

        version = _server_version(engine)
        if hasattr(self.op, "__call__"):
            oper = self.op
        else:
            oper = self._ops[self.op]
        return oper(version, self.spec)

    def _as_string(self, config, negate=False):
        if self.description is not None:
            return self._format_description(config)
        prefix = "not " if negate else ""
        if self.op is None:
            return prefix + "%s" % self.db
        return prefix + "%s %s %s" % (self.db, self.op, self.spec)
+
+
class LambdaPredicate(Predicate):
    """Predicate evaluating an arbitrary callable against the config."""

    def __init__(self, lambda_, description=None, args=None, kw=None):
        # zero-argument callables are wrapped so everything may be
        # invoked uniformly with the config argument
        if inspect_getfullargspec(lambda_)[0]:
            self.lambda_ = lambda_
        else:
            self.lambda_ = lambda db: lambda_()
        self.args = args or ()
        self.kw = kw or {}
        if description:
            self.description = description
        elif lambda_.__doc__:
            self.description = lambda_.__doc__
        else:
            self.description = "custom function"

    def __call__(self, config):
        return self.lambda_(config)

    def _as_string(self, config, negate=False):
        return self._format_description(config)
+
+
class NotPredicate(Predicate):
    """Inverts the result of another predicate."""

    def __init__(self, predicate, description=None):
        self.predicate = predicate
        self.description = description

    def __call__(self, config):
        return not self.predicate(config)

    def _as_string(self, config, negate=False):
        # flip the negate flag when describing the wrapped predicate
        if self.description:
            return self._format_description(config, not negate)
        return self.predicate._as_string(config, not negate)
+
+
class OrPredicate(Predicate):
    """True when any one of a list of predicates is true."""

    def __init__(self, predicates, description=None):
        self.predicates = predicates
        self.description = description

    def __call__(self, config):
        return any(pred(config) for pred in self.predicates)

    def _eval_str(self, config, negate=False):
        # De Morgan: the negated form joins with "and"
        joiner = " and " if negate else " or "
        return joiner.join(
            p._as_string(config, negate=negate) for p in self.predicates
        )

    def _negation_str(self, config):
        if self.description is not None:
            return "Not " + self._format_description(config)
        return self._eval_str(config, negate=True)

    def _as_string(self, config, negate=False):
        if negate:
            return self._negation_str(config)
        if self.description is not None:
            return self._format_description(config)
        return self._eval_str(config)
+
+
+_as_predicate = Predicate.as_predicate
+
+
def _is_excluded(db, op, spec):
    """Return True if the current config matches the given spec."""
    predicate = SpecPredicate(db, op, spec)
    return predicate(config._current)
+
+
+def _server_version(engine):
+ """Return a server_version_info tuple."""
+
+ # force metadata to be retrieved
+ conn = engine.connect()
+ version = getattr(engine.dialect, "server_version_info", None)
+ if version is None:
+ version = ()
+ conn.close()
+ return version
+
+
def db_spec(*dbs):
    """Build a predicate matching any of the given database specs."""
    predicates = [Predicate.as_predicate(db) for db in dbs]
    return OrPredicate(predicates)
+
+
def open():  # noqa
    """Rule that never skips; the test always executes.

    NOTE: intentionally shadows the ``open`` builtin within this module.
    """
    return skip_if(BooleanPredicate(False, "mark as execute"))
+
+
def closed():
    """Rule that unconditionally skips the test."""
    return skip_if(BooleanPredicate(True, "marked as skip"))
+
+
def fails(reason=None):
    """Rule that unconditionally expects the test to fail."""
    return fails_if(BooleanPredicate(True, reason or "expected to fail"))
+
+
@decorator
def future(fn, *arg):
    """Mark the decorated callable as a not-yet-implemented feature.

    NOTE(review): the decorated call returns the compound produced by
    ``fails_if`` rather than invoking ``fn`` directly -- presumably
    consumed by the requirements machinery; confirm against callers.
    """
    return fails_if(LambdaPredicate(fn), "Future feature")
+
+
def fails_on(db, reason=None):
    """Expect failure when running on the given database spec."""
    return fails_if(db, reason)
+
+
def fails_on_everything_except(*dbs):
    """Expect failure on every backend not in the given specs."""
    return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
+
+
def skip(db, reason=None):
    """Skip the test when running on the given database spec."""
    return skip_if(db, reason)
+
+
def only_on(dbs, reason=None):
    """Run the test only on the given database spec(s)."""
    return only_if(
        OrPredicate(
            [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
        )
    )
+
+
def exclude(db, op, spec, reason=None):
    """Skip the test for database/version combos matching the spec."""
    return skip_if(SpecPredicate(db, op, spec), reason)
+
+
def against(config, *queries):
    """Return True if ``config`` matches any of the given db specs."""
    assert queries, "no queries sent!"
    return OrPredicate([Predicate.as_predicate(query) for query in queries])(
        config
    )
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
new file mode 100644
index 0000000..0a2d63b
--- /dev/null
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -0,0 +1,870 @@
+# testing/fixtures.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import contextlib
+import re
+import sys
+
+import sqlalchemy as sa
+from . import assertions
+from . import config
+from . import schema
+from .entities import BasicEntity
+from .entities import ComparableEntity
+from .entities import ComparableMixin # noqa
+from .util import adict
+from .util import drop_all_tables_from_metadata
+from .. import event
+from .. import util
+from ..orm import declarative_base
+from ..orm import registry
+from ..orm.decl_api import DeclarativeMeta
+from ..schema import sort_tables_and_constraints
+
+
+@config.mark_base_test_class()
+class TestBase(object):
+ # A sequence of requirement names matching testing.requires decorators
+ __requires__ = ()
+
+ # A sequence of dialect names to exclude from the test class.
+ __unsupported_on__ = ()
+
+ # If present, test class is only runnable for the *single* specified
+ # dialect. If you need multiple, use __unsupported_on__ and invert.
+ __only_on__ = None
+
+ # A sequence of no-arg callables. If any are True, the entire testcase is
+ # skipped.
+ __skip_if__ = None
+
+ # if True, the testing reaper will not attempt to touch connection
+ # state after a test is completed and before the outer teardown
+ # starts
+ __leave_connections_for_teardown__ = False
+
+ def assert_(self, val, msg=None):
+ assert val, msg
+
+ @config.fixture()
+ def nocache(self):
+ _cache = config.db._compiled_cache
+ config.db._compiled_cache = None
+ yield
+ config.db._compiled_cache = _cache
+
+ @config.fixture()
+ def connection_no_trans(self):
+ eng = getattr(self, "bind", None) or config.db
+
+ with eng.connect() as conn:
+ yield conn
+
+ @config.fixture()
+ def connection(self):
+ global _connection_fixture_connection
+
+ eng = getattr(self, "bind", None) or config.db
+
+ conn = eng.connect()
+ trans = conn.begin()
+
+ _connection_fixture_connection = conn
+ yield conn
+
+ _connection_fixture_connection = None
+
+ if trans.is_active:
+ trans.rollback()
+ # trans would not be active here if the test is using
+ # the legacy @provide_metadata decorator still, as it will
+ # run a close all connections.
+ conn.close()
+
+ @config.fixture()
+ def close_result_when_finished(self):
+ to_close = []
+ to_consume = []
+
+ def go(result, consume=False):
+ to_close.append(result)
+ if consume:
+ to_consume.append(result)
+
+ yield go
+ for r in to_consume:
+ try:
+ r.all()
+ except:
+ pass
+ for r in to_close:
+ try:
+ r.close()
+ except:
+ pass
+
+ @config.fixture()
+ def registry(self, metadata):
+ reg = registry(metadata=metadata)
+ yield reg
+ reg.dispose()
+
+ @config.fixture
+ def decl_base(self, registry):
+ return registry.generate_base()
+
+ @config.fixture()
+ def future_connection(self, future_engine, connection):
+ # integrate the future_engine and connection fixtures so
+ # that users of the "connection" fixture will get at the
+ # "future" connection
+ yield connection
+
+ @config.fixture()
+ def future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
+
+ @config.fixture()
+ def testing_engine(self):
+ from . import engines
+
+ def gen_testing_engine(
+ url=None,
+ options=None,
+ future=None,
+ asyncio=False,
+ transfer_staticpool=False,
+ ):
+ if options is None:
+ options = {}
+ options["scope"] = "fixture"
+ return engines.testing_engine(
+ url=url,
+ options=options,
+ future=future,
+ asyncio=asyncio,
+ transfer_staticpool=transfer_staticpool,
+ )
+
+ yield gen_testing_engine
+
+ engines.testing_reaper._drop_testing_engines("fixture")
+
+ @config.fixture()
+ def async_testing_engine(self, testing_engine):
+ def go(**kw):
+ kw["asyncio"] = True
+ return testing_engine(**kw)
+
+ return go
+
+ @config.fixture
+ def fixture_session(self):
+ return fixture_session()
+
+ @config.fixture()
+ def metadata(self, request):
+ """Provide bound MetaData for a single test, dropping afterwards."""
+
+ from ..sql import schema
+
+ metadata = schema.MetaData()
+ request.instance.metadata = metadata
+ yield metadata
+ del request.instance.metadata
+
+ if (
+ _connection_fixture_connection
+ and _connection_fixture_connection.in_transaction()
+ ):
+ trans = _connection_fixture_connection.get_transaction()
+ trans.rollback()
+ with _connection_fixture_connection.begin():
+ drop_all_tables_from_metadata(
+ metadata, _connection_fixture_connection
+ )
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
+
+ @config.fixture(
+ params=[
+ (rollback, second_operation, begin_nested)
+ for rollback in (True, False)
+ for second_operation in ("none", "execute", "begin")
+ for begin_nested in (
+ True,
+ False,
+ )
+ ]
+ )
+ def trans_ctx_manager_fixture(self, request, metadata):
+ rollback, second_operation, begin_nested = request.param
+
+ from sqlalchemy import Table, Column, Integer, func, select
+ from . import eq_
+
+ t = Table("test", metadata, Column("data", Integer))
+ eng = getattr(self, "bind", None) or config.db
+
+ t.create(eng)
+
+ def run_test(subject, trans_on_subject, execute_on_subject):
+ with subject.begin() as trans:
+
+ if begin_nested:
+ if not config.requirements.savepoints.enabled:
+ config.skip_test("savepoints not enabled")
+ if execute_on_subject:
+ nested_trans = subject.begin_nested()
+ else:
+ nested_trans = trans.begin_nested()
+
+ with nested_trans:
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 10})
+ else:
+ trans.execute(t.insert(), {"data": 10})
+
+ # for nested trans, we always commit/rollback on the
+ # "nested trans" object itself.
+ # only Session(future=False) will affect savepoint
+ # transaction for session.commit/rollback
+
+ if rollback:
+ nested_trans.rollback()
+ else:
+ nested_trans.commit()
+
+ if second_operation != "none":
+ with assertions.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "Can't operate on closed transaction "
+ "inside context "
+ "manager. Please complete the context "
+ "manager "
+ "before emitting further commands.",
+ ):
+ if second_operation == "execute":
+ if execute_on_subject:
+ subject.execute(
+ t.insert(), {"data": 12}
+ )
+ else:
+ trans.execute(t.insert(), {"data": 12})
+ elif second_operation == "begin":
+ if execute_on_subject:
+ subject.begin_nested()
+ else:
+ trans.begin_nested()
+
+ # outside the nested trans block, but still inside the
+ # transaction block, we can run SQL, and it will be
+ # committed
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 14})
+ else:
+ trans.execute(t.insert(), {"data": 14})
+
+ else:
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 10})
+ else:
+ trans.execute(t.insert(), {"data": 10})
+
+ if trans_on_subject:
+ if rollback:
+ subject.rollback()
+ else:
+ subject.commit()
+ else:
+ if rollback:
+ trans.rollback()
+ else:
+ trans.commit()
+
+ if second_operation != "none":
+ with assertions.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "Can't operate on closed transaction inside "
+ "context "
+ "manager. Please complete the context manager "
+ "before emitting further commands.",
+ ):
+ if second_operation == "execute":
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 12})
+ else:
+ trans.execute(t.insert(), {"data": 12})
+ elif second_operation == "begin":
+ if hasattr(trans, "begin"):
+ trans.begin()
+ else:
+ subject.begin()
+ elif second_operation == "begin_nested":
+ if execute_on_subject:
+ subject.begin_nested()
+ else:
+ trans.begin_nested()
+
+ expected_committed = 0
+ if begin_nested:
+ # begin_nested variant, we inserted a row after the nested
+ # block
+ expected_committed += 1
+ if not rollback:
+ # not rollback variant, our row inserted in the target
+ # block itself would be committed
+ expected_committed += 1
+
+ if execute_on_subject:
+ eq_(
+ subject.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+ else:
+ with subject.connect() as conn:
+ eq_(
+ conn.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+
+ return run_test
+
+
+_connection_fixture_connection = None
+
+
@contextlib.contextmanager
def _push_future_engine(engine):
    # Temporarily install a 2.0-style facade of ``engine`` as the current
    # test engine, popping it again on exit.

    from ..future.engine import Engine
    from sqlalchemy import testing

    facade = Engine._future_facade(engine)
    config._current.push_engine(facade, testing)

    yield facade

    config._current.pop(testing)
+
+
class FutureEngineMixin(object):
    """Mixin that runs the whole test class against a 2.0-style facade
    of the default (or class-bound) engine."""

    @config.fixture(autouse=True, scope="class")
    def _push_future_engine(self):
        eng = getattr(self, "bind", None) or config.db
        with _push_future_engine(eng):
            yield
+
+
+class TablesTest(TestBase):
+
+ # 'once', None
+ run_setup_bind = "once"
+
+ # 'once', 'each', None
+ run_define_tables = "once"
+
+ # 'once', 'each', None
+ run_create_tables = "once"
+
+ # 'once', 'each', None
+ run_inserts = "each"
+
+ # 'each', None
+ run_deletes = "each"
+
+ # 'once', None
+ run_dispose_bind = None
+
+ bind = None
+ _tables_metadata = None
+ tables = None
+ other = None
+ sequences = None
+
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
+ cls._init_class()
+
+ cls._setup_once_tables()
+
+ cls._setup_once_inserts()
+
+ yield
+
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_inserts()
+
+ yield
+
+ self._teardown_each_tables()
+
+ @property
+ def tables_test_metadata(self):
+ return self._tables_metadata
+
+ @classmethod
+ def _init_class(cls):
+ if cls.run_define_tables == "each":
+ if cls.run_create_tables == "once":
+ cls.run_create_tables = "each"
+ assert cls.run_inserts in ("each", None)
+
+ cls.other = adict()
+ cls.tables = adict()
+ cls.sequences = adict()
+
+ cls.bind = cls.setup_bind()
+ cls._tables_metadata = sa.MetaData()
+
+ @classmethod
+ def _setup_once_inserts(cls):
+ if cls.run_inserts == "once":
+ cls._load_fixtures()
+ with cls.bind.begin() as conn:
+ cls.insert_data(conn)
+
+ @classmethod
+ def _setup_once_tables(cls):
+ if cls.run_define_tables == "once":
+ cls.define_tables(cls._tables_metadata)
+ if cls.run_create_tables == "once":
+ cls._tables_metadata.create_all(cls.bind)
+ cls.tables.update(cls._tables_metadata.tables)
+ cls.sequences.update(cls._tables_metadata._sequences)
+
+ def _setup_each_tables(self):
+ if self.run_define_tables == "each":
+ self.define_tables(self._tables_metadata)
+ if self.run_create_tables == "each":
+ self._tables_metadata.create_all(self.bind)
+ self.tables.update(self._tables_metadata.tables)
+ self.sequences.update(self._tables_metadata._sequences)
+ elif self.run_create_tables == "each":
+ self._tables_metadata.create_all(self.bind)
+
+ def _setup_each_inserts(self):
+ if self.run_inserts == "each":
+ self._load_fixtures()
+ with self.bind.begin() as conn:
+ self.insert_data(conn)
+
+ def _teardown_each_tables(self):
+ if self.run_define_tables == "each":
+ self.tables.clear()
+ if self.run_create_tables == "each":
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+ self._tables_metadata.clear()
+ elif self.run_create_tables == "each":
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+
+ savepoints = getattr(config.requirements, "savepoints", False)
+ if savepoints:
+ savepoints = savepoints.enabled
+
+ # no need to run deletes if tables are recreated on setup
+ if (
+ self.run_define_tables != "each"
+ and self.run_create_tables != "each"
+ and self.run_deletes == "each"
+ ):
+ with self.bind.begin() as conn:
+ for table in reversed(
+ [
+ t
+ for (t, fks) in sort_tables_and_constraints(
+ self._tables_metadata.tables.values()
+ )
+ if t is not None
+ ]
+ ):
+ try:
+ if savepoints:
+ with conn.begin_nested():
+ conn.execute(table.delete())
+ else:
+ conn.execute(table.delete())
+ except sa.exc.DBAPIError as ex:
+ util.print_(
+ ("Error emptying table %s: %r" % (table, ex)),
+ file=sys.stderr,
+ )
+
+ @classmethod
+ def _teardown_once_metadata_bind(cls):
+ if cls.run_create_tables:
+ drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
+
+ if cls.run_dispose_bind == "once":
+ cls.dispose_bind(cls.bind)
+
+ cls._tables_metadata.bind = None
+
+ if cls.run_setup_bind is not None:
+ cls.bind = None
+
+ @classmethod
+ def setup_bind(cls):
+ return config.db
+
+ @classmethod
+ def dispose_bind(cls, bind):
+ if hasattr(bind, "dispose"):
+ bind.dispose()
+ elif hasattr(bind, "close"):
+ bind.close()
+
+ @classmethod
+ def define_tables(cls, metadata):
+ pass
+
+ @classmethod
+ def fixtures(cls):
+ return {}
+
+ @classmethod
+ def insert_data(cls, connection):
+ pass
+
+ def sql_count_(self, count, fn):
+ self.assert_sql_count(self.bind, fn, count)
+
+ def sql_eq_(self, callable_, statements):
+ self.assert_sql(self.bind, callable_, statements)
+
    @classmethod
    def _load_fixtures(cls):
        """Insert rows as represented by the fixtures() method.

        Fixture data is ``{table_or_name: (header_row, row1, row2, ...)}``.
        """
        headers, rows = {}, {}
        for table, data in cls.fixtures().items():
            if len(data) < 2:
                # header row only (or empty) -- nothing to insert
                continue
            if isinstance(table, util.string_types):
                table = cls.tables[table]
            headers[table] = data[0]
            rows[table] = data[1:]
        # insert in dependency order so FK references resolve
        for table, fks in sort_tables_and_constraints(
            cls._tables_metadata.tables.values()
        ):
            if table is None:
                continue
            if table not in headers:
                continue
            with cls.bind.begin() as conn:
                conn.execute(
                    table.insert(),
                    [
                        dict(zip(headers[table], column_values))
                        for column_values in rows[table]
                    ],
                )
+
+
class NoCache(object):
    """Mixin disabling the compiled-SQL cache on the default test engine
    for each test function."""

    @config.fixture(autouse=True, scope="function")
    def _disable_cache(self):
        # swap out the compiled cache, restoring it after the test
        _cache = config.db._compiled_cache
        config.db._compiled_cache = None
        yield
        config.db._compiled_cache = _cache
+
+
class RemovesEvents(object):
    """Mixin that tracks event listeners registered via ``event_listen``
    and removes them after each test function."""

    @util.memoized_property
    def _event_fns(self):
        # lazily-created per-instance registry of (target, name, fn)
        return set()

    def event_listen(self, target, name, fn, **kw):
        self._event_fns.add((target, name, fn))
        event.listen(target, name, fn, **kw)

    @config.fixture(autouse=True, scope="function")
    def _remove_events(self):
        yield
        for key in self._event_fns:
            event.remove(*key)
+
+
_fixture_sessions = set()


def fixture_session(**kw):
    """Create an ORM Session bound to the test engine and register it
    for later cleanup via ``_close_all_sessions``."""
    kw.setdefault("autoflush", True)
    kw.setdefault("expire_on_commit", True)
    bind = kw.pop("bind", config.db)
    session = sa.orm.Session(bind, **kw)
    _fixture_sessions.add(session)
    return session
+
+
def _close_all_sessions():
    """Close every ORM session and reset the fixture-session registry."""
    # will close all still-referenced sessions
    sa.orm.session.close_all_sessions()
    _fixture_sessions.clear()
+
+
def stop_test_class_inside_fixtures(cls):
    """Class-level teardown hook: close sessions and clear all mappers."""
    _close_all_sessions()
    sa.orm.clear_mappers()
+
+
def after_test():
    """Per-test teardown hook: dispose of any fixture sessions."""
    if _fixture_sessions:
        _close_all_sessions()
+
+
class ORMTest(TestBase):
    """Base class for ORM tests; currently adds nothing beyond TestBase."""

    pass
+
+
+class MappedTest(TablesTest, assertions.AssertsExecutionResults):
+ # 'once', 'each', None
+ run_setup_classes = "once"
+
+ # 'once', 'each', None
+ run_setup_mappers = "each"
+
+ classes = None
+
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
+ cls._init_class()
+
+ if cls.classes is None:
+ cls.classes = adict()
+
+ cls._setup_once_tables()
+ cls._setup_once_classes()
+ cls._setup_once_mappers()
+ cls._setup_once_inserts()
+
+ yield
+
+ cls._teardown_once_class()
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_classes()
+ self._setup_each_mappers()
+ self._setup_each_inserts()
+
+ yield
+
+ sa.orm.session.close_all_sessions()
+ self._teardown_each_mappers()
+ self._teardown_each_classes()
+ self._teardown_each_tables()
+
+ @classmethod
+ def _teardown_once_class(cls):
+ cls.classes.clear()
+
+ @classmethod
+ def _setup_once_classes(cls):
+ if cls.run_setup_classes == "once":
+ cls._with_register_classes(cls.setup_classes)
+
+ @classmethod
+ def _setup_once_mappers(cls):
+ if cls.run_setup_mappers == "once":
+ cls.mapper_registry, cls.mapper = cls._generate_registry()
+ cls._with_register_classes(cls.setup_mappers)
+
+ def _setup_each_mappers(self):
+ if self.run_setup_mappers != "once":
+ (
+ self.__class__.mapper_registry,
+ self.__class__.mapper,
+ ) = self._generate_registry()
+
+ if self.run_setup_mappers == "each":
+ self._with_register_classes(self.setup_mappers)
+
+ def _setup_each_classes(self):
+ if self.run_setup_classes == "each":
+ self._with_register_classes(self.setup_classes)
+
+ @classmethod
+ def _generate_registry(cls):
+ decl = registry(metadata=cls._tables_metadata)
+ return decl, decl.map_imperatively
+
+ @classmethod
+ def _with_register_classes(cls, fn):
+ """Run a setup method, framing the operation with a Base class
+ that will catch new subclasses to be established within
+ the "classes" registry.
+
+ """
+ cls_registry = cls.classes
+
+ assert cls_registry is not None
+
+ class FindFixture(type):
+ def __init__(cls, classname, bases, dict_):
+ cls_registry[classname] = cls
+ type.__init__(cls, classname, bases, dict_)
+
+ class _Base(util.with_metaclass(FindFixture, object)):
+ pass
+
+ class Basic(BasicEntity, _Base):
+ pass
+
+ class Comparable(ComparableEntity, _Base):
+ pass
+
+ cls.Basic = Basic
+ cls.Comparable = Comparable
+ fn()
+
+ def _teardown_each_mappers(self):
+ # some tests create mappers in the test bodies
+ # and will define setup_mappers as None -
+ # clear mappers in any case
+ if self.run_setup_mappers != "once":
+ sa.orm.clear_mappers()
+
+ def _teardown_each_classes(self):
+ if self.run_setup_classes != "once":
+ self.classes.clear()
+
+ @classmethod
+ def setup_classes(cls):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ pass
+
+
class DeclarativeMappedTest(MappedTest):
    """MappedTest variant whose classes are declaratively mapped; new
    subclasses of ``cls.DeclarativeBasic`` are collected into
    ``cls.classes`` automatically."""

    run_setup_classes = "once"
    run_setup_mappers = "once"

    @classmethod
    def _setup_once_tables(cls):
        # tables come from the declarative classes; nothing to do here
        pass

    @classmethod
    def _with_register_classes(cls, fn):
        cls_registry = cls.classes

        class FindFixtureDeclarative(DeclarativeMeta):
            # metaclass hook: record each declarative class by name
            def __init__(cls, classname, bases, dict_):
                cls_registry[classname] = cls
                DeclarativeMeta.__init__(cls, classname, bases, dict_)

        class DeclarativeBasic(object):
            __table_cls__ = schema.Table

        _DeclBase = declarative_base(
            metadata=cls._tables_metadata,
            metaclass=FindFixtureDeclarative,
            cls=DeclarativeBasic,
        )

        cls.DeclarativeBasic = _DeclBase

        # sets up cls.Basic which is helpful for things like composite
        # classes
        super(DeclarativeMappedTest, cls)._with_register_classes(fn)

        if cls._tables_metadata.tables and cls.run_create_tables:
            cls._tables_metadata.create_all(config.db)
+
+
+class ComputedReflectionFixtureTest(TablesTest):
+ run_inserts = run_deletes = None
+
+ __backend__ = True
+ __requires__ = ("computed_columns", "table_reflection")
+
+ regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
+
+ def normalize(self, text):
+ return self.regexp.sub("", text).lower()
+
+ @classmethod
+ def define_tables(cls, metadata):
+ from .. import Integer
+ from .. import testing
+ from ..schema import Column
+ from ..schema import Computed
+ from ..schema import Table
+
+ Table(
+ "computed_default_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_col", Integer, Computed("normal + 42")),
+ Column("with_default", Integer, server_default="42"),
+ )
+
+ t = Table(
+ "computed_column_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_no_flag", Integer, Computed("normal + 42")),
+ )
+
+ if testing.requires.schemas.enabled:
+ t2 = Table(
+ "computed_column_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_no_flag", Integer, Computed("normal / 42")),
+ schema=config.test_schema,
+ )
+
+ if testing.requires.computed_columns_virtual.enabled:
+ t.append_column(
+ Column(
+ "computed_virtual",
+ Integer,
+ Computed("normal + 2", persisted=False),
+ )
+ )
+ if testing.requires.schemas.enabled:
+ t2.append_column(
+ Column(
+ "computed_virtual",
+ Integer,
+ Computed("normal / 2", persisted=False),
+ )
+ )
+ if testing.requires.computed_columns_stored.enabled:
+ t.append_column(
+ Column(
+ "computed_stored",
+ Integer,
+ Computed("normal - 42", persisted=True),
+ )
+ )
+ if testing.requires.schemas.enabled:
+ t2.append_column(
+ Column(
+ "computed_stored",
+ Integer,
+ Computed("normal * 42", persisted=True),
+ )
+ )
diff --git a/lib/sqlalchemy/testing/mock.py b/lib/sqlalchemy/testing/mock.py
new file mode 100644
index 0000000..e333c70
--- /dev/null
+++ b/lib/sqlalchemy/testing/mock.py
@@ -0,0 +1,32 @@
+# testing/mock.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Import stub for mock library.
+"""
+from __future__ import absolute_import
+
+from ..util import py3k
+
+
# Python 3 ships mock in the standard library; under Python 2 the
# third-party "mock" backport is required.
if py3k:
    from unittest.mock import MagicMock
    from unittest.mock import Mock
    from unittest.mock import call
    from unittest.mock import patch
    from unittest.mock import ANY
else:
    try:
        from mock import MagicMock  # noqa
        from mock import Mock  # noqa
        from mock import call  # noqa
        from mock import patch  # noqa
        from mock import ANY  # noqa
    except ImportError:
        # fail loudly at import time rather than with obscure
        # AttributeErrors later in the suite
        raise ImportError(
            "SQLAlchemy's test suite requires the "
            "'mock' library as of 0.8.2."
        )
diff --git a/lib/sqlalchemy/testing/pickleable.py b/lib/sqlalchemy/testing/pickleable.py
new file mode 100644
index 0000000..f05960c
--- /dev/null
+++ b/lib/sqlalchemy/testing/pickleable.py
@@ -0,0 +1,151 @@
+# testing/pickleable.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Classes used in pickling tests, need to be at the module level for
+unpickling.
+"""
+
+from . import fixtures
+from ..schema import Column
+from ..types import String
+
+
class User(fixtures.ComparableEntity):
    """Placeholder mapped class for pickling round-trip tests."""

    pass


class Order(fixtures.ComparableEntity):
    """Placeholder mapped class for pickling round-trip tests."""

    pass


class Dingaling(fixtures.ComparableEntity):
    """Placeholder mapped class for pickling round-trip tests."""

    pass


class EmailUser(User):
    """Subclass of User, used for inheritance pickling tests."""

    pass


class Address(fixtures.ComparableEntity):
    """Placeholder mapped class for pickling round-trip tests."""

    pass


# TODO: these are kind of arbitrary....
class Child1(fixtures.ComparableEntity):
    """Placeholder mapped class for pickling round-trip tests."""

    pass


class Child2(fixtures.ComparableEntity):
    """Placeholder mapped class for pickling round-trip tests."""

    pass


class Parent(fixtures.ComparableEntity):
    """Placeholder mapped class for pickling round-trip tests."""

    pass
+
+
class Screen(object):
    """Plain object linking a value to an optional parent ``Screen``,
    used to build picklable object graphs."""

    def __init__(self, obj, parent=None):
        # held value and optional link upward in the graph
        self.obj = obj
        self.parent = parent
+
+
class Mixin(object):
    # declarative mixin contributing a mapped column to subclasses
    email_address = Column(String)


class AddressWMixin(Mixin, fixtures.ComparableEntity):
    """Mapped class that receives ``email_address`` via :class:`Mixin`."""

    pass
+
+
class Foo(object):
    """Pickling test object with a fixed ``data`` attribute and
    value-based equality over its three attributes."""

    def __init__(self, moredata, stuff="im stuff"):
        self.data = "im data"
        self.stuff = stuff
        self.moredata = moredata

    __hash__ = object.__hash__

    def __eq__(self, other):
        # attribute-by-attribute comparison, short-circuiting in the
        # same order as a chained ``and`` expression
        for attr in ("data", "stuff", "moredata"):
            if getattr(other, attr) != getattr(self, attr):
                return False
        return True
+
+
class Bar(object):
    """Value object compared by exact class plus ``x``/``y`` coordinates."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    __hash__ = object.__hash__

    def __eq__(self, other):
        # exact-class check first, then coordinate comparison
        if other.__class__ is not self.__class__:
            return False
        return (other.x, other.y) == (self.x, self.y)

    def __str__(self):
        return "Bar(%d, %d)" % (self.x, self.y)
+
+
class OldSchool:
    """Old-style (Python 2) class with value-based equality, used in
    pickling tests."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        # class identity is checked first so unrelated objects compare
        # unequal without attribute access
        same_class = other.__class__ is self.__class__
        return same_class and other.x == self.x and other.y == self.y
+
+
class OldSchoolWithoutCompare:
    """Old-style (Python 2) class with no comparison methods, so the
    default identity comparison applies."""

    def __init__(self, x, y):
        self.x, self.y = x, y
+
+
class BarWithoutCompare(object):
    """Like :class:`Bar` but relying on default identity comparison."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __str__(self):
        # renders identically to Bar
        template = "Bar(%d, %d)"
        return template % (self.x, self.y)
+
+
class NotComparable(object):
    """Holds ``data`` but declines equality comparison entirely.

    Returning ``NotImplemented`` from both ``__eq__`` and ``__ne__``
    makes Python fall back to identity comparison.
    """

    def __init__(self, data):
        self.data = data

    def __eq__(self, other):
        return NotImplemented

    def __ne__(self, other):
        return NotImplemented

    def __hash__(self):
        # identity hash, consistent with the identity-comparison fallback
        return id(self)
+
+
class BrokenComparable(object):
    """Object whose comparison methods always raise, used to exercise
    error-handling paths around comparisons."""

    def __init__(self, data):
        self.data = data

    def __eq__(self, other):
        raise NotImplementedError

    def __ne__(self, other):
        raise NotImplementedError

    def __hash__(self):
        # hashing still works; only comparisons blow up
        return id(self)
diff --git a/lib/sqlalchemy/testing/plugin/__init__.py b/lib/sqlalchemy/testing/plugin/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/__init__.py
diff --git a/lib/sqlalchemy/testing/plugin/bootstrap.py b/lib/sqlalchemy/testing/plugin/bootstrap.py
new file mode 100644
index 0000000..6721f48
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/bootstrap.py
@@ -0,0 +1,54 @@
+"""
+Bootstrapper for test framework plugins.
+
+The entire rationale for this system is to get the modules in plugin/
+imported without importing all of the supporting library, so that we can
+set up things for testing before coverage starts.
+
+The rationale for all of plugin/ being *in* the supporting library in the
+first place is so that the testing and plugin suite is available to other
+libraries, mainly external SQLAlchemy and Alembic dialects, to make use
+of the same test environment and standard suites available to
+SQLAlchemy/Alembic themselves without the need to ship/install a separate
+package outside of SQLAlchemy.
+
+
+"""
+
+import os
+import sys
+
+
# these names are injected into this module's namespace by the caller
# that exec()s this file; the locals() lookup keeps static analyzers quiet
bootstrap_file = locals()["bootstrap_file"]
to_bootstrap = locals()["to_bootstrap"]
+
+
def load_file_as_module(name):
    """Import ``<name>.py`` from the bootstrap file's directory without
    touching ``sys.path`` and return the new module object."""
    directory = os.path.dirname(bootstrap_file)
    path = os.path.join(directory, "%s.py" % name)

    if sys.version_info < (3, 5):
        # legacy single-call loader for old interpreters
        import imp

        return imp.load_source(name, path)

    import importlib.util

    spec = importlib.util.spec_from_file_location(name, path)
    assert spec is not None
    assert spec.loader is not None
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
+
+
# register the plugin modules under stable aliases in sys.modules so they
# can be imported by name before the sqlalchemy package itself is imported
# (keeping coverage measurement clean)
if to_bootstrap == "pytest":
    sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
    sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
    if sys.version_info < (3, 0):
        # Python 2 needs the fixture-reinvention shim loaded first
        sys.modules["sqla_reinvent_fixtures"] = load_file_as_module(
            "reinvent_fixtures_py2k"
        )
    sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
else:
    raise Exception("unknown bootstrap: %s" % to_bootstrap)  # noqa
diff --git a/lib/sqlalchemy/testing/plugin/plugin_base.py b/lib/sqlalchemy/testing/plugin/plugin_base.py
new file mode 100644
index 0000000..d59564e
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/plugin_base.py
@@ -0,0 +1,789 @@
+# plugin/plugin_base.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Testing extensions.
+
+this module is designed to work as a testing-framework-agnostic library,
+created so that multiple test frameworks can be supported at once
+(mostly so that we can migrate to new ones). The current target
+is pytest.
+
+"""
+
+from __future__ import absolute_import
+
+import abc
+import logging
+import re
+import sys
+
+# flag which indicates we are in the SQLAlchemy testing suite,
+# and not that of Alembic or a third party dialect.
+bootstrapped_as_sqlalchemy = False
+
+log = logging.getLogger("sqlalchemy.testing.plugin_base")
+
+
py3k = sys.version_info >= (3, 0)

if py3k:
    import configparser

    ABC = abc.ABC
else:
    import ConfigParser as configparser
    import collections as collections_abc  # noqa

    # Python 2 has no abc.ABC shortcut; emulate it via the metaclass
    class ABC(object):
        __metaclass__ = abc.ABCMeta


# late imports: these module globals are populated by post_begin() once
# coverage is known to be running, so that the bulk of the library is not
# imported before coverage starts
fixtures = None
engines = None
exclusions = None
warnings = None
profiling = None
provision = None
assertions = None
requirements = None
config = None
testing = None
util = None
file_config = None

# ``logging`` is deliberately rebound to None here and re-imported lazily
# by _log(); the tag sets are filled in by option callbacks
logging = None
include_tags = set()
exclude_tags = set()
options = None
+
+
def setup_options(make_option):
    """Declare all sqlalchemy testing command-line options.

    ``make_option`` is a framework-supplied callable with an
    optparse-like signature (see the pytest plugin's adapter); callbacks
    are invoked as options are parsed.
    """
    make_option(
        "--log-info",
        action="callback",
        type=str,
        callback=_log,
        help="turn on info logging for <LOG> (multiple OK)",
    )
    make_option(
        "--log-debug",
        action="callback",
        type=str,
        callback=_log,
        help="turn on debug logging for <LOG> (multiple OK)",
    )
    make_option(
        "--db",
        action="append",
        type=str,
        dest="db",
        help="Use prefab database uri. Multiple OK, "
        "first one is run by default.",
    )
    make_option(
        "--dbs",
        action="callback",
        zeroarg_callback=_list_dbs,
        help="List available prefab dbs",
    )
    make_option(
        "--dburi",
        action="append",
        type=str,
        dest="dburi",
        help="Database uri. Multiple OK, " "first one is run by default.",
    )
    make_option(
        "--dbdriver",
        action="append",
        type=str,
        dest="dbdriver",
        help="Additional database drivers to include in tests. "
        "These are linked to the existing database URLs by the "
        "provisioning system.",
    )
    make_option(
        "--dropfirst",
        action="store_true",
        dest="dropfirst",
        help="Drop all tables in the target database first",
    )
    make_option(
        "--disable-asyncio",
        action="store_true",
        # fixed typo: "provisoning" -> "provisioning"
        help="disable test / fixtures / provisioning running in asyncio",
    )
    make_option(
        "--backend-only",
        action="store_true",
        dest="backend_only",
        help="Run only tests marked with __backend__ or __sparse_backend__",
    )
    make_option(
        "--nomemory",
        action="store_true",
        dest="nomemory",
        help="Don't run memory profiling tests",
    )
    make_option(
        "--notimingintensive",
        action="store_true",
        dest="notimingintensive",
        help="Don't run timing intensive tests",
    )
    make_option(
        "--profile-sort",
        type=str,
        default="cumulative",
        dest="profilesort",
        help="Type of sort for profiling standard output",
    )
    make_option(
        "--profile-dump",
        type=str,
        dest="profiledump",
        help="Filename where a single profile run will be dumped",
    )
    make_option(
        "--postgresql-templatedb",
        type=str,
        help="name of template database to use for PostgreSQL "
        "CREATE DATABASE (defaults to current database)",
    )
    make_option(
        "--low-connections",
        action="store_true",
        dest="low_connections",
        help="Use a low number of distinct connections - "
        "i.e. for Oracle TNS",
    )
    make_option(
        "--write-idents",
        type=str,
        dest="write_idents",
        help="write out generated follower idents to <file>, "
        "when -n<num> is used",
    )
    make_option(
        "--reversetop",
        action="store_true",
        dest="reversetop",
        default=False,
        help="Use a random-ordering set implementation in the ORM "
        "(helps reveal dependency issues)",
    )
    make_option(
        "--requirements",
        action="callback",
        type=str,
        callback=_requirements_opt,
        help="requirements class for testing, overrides setup.cfg",
    )
    make_option(
        "--with-cdecimal",
        action="store_true",
        dest="cdecimal",
        default=False,
        help="Monkeypatch the cdecimal library into Python 'decimal' "
        "for all tests",
    )
    make_option(
        "--include-tag",
        action="callback",
        callback=_include_tag,
        type=str,
        help="Include tests with tag <tag>",
    )
    make_option(
        "--exclude-tag",
        action="callback",
        callback=_exclude_tag,
        type=str,
        help="Exclude tests with tag <tag>",
    )
    make_option(
        "--write-profiles",
        action="store_true",
        dest="write_profiles",
        default=False,
        help="Write/update failing profiling data.",
    )
    make_option(
        "--force-write-profiles",
        action="store_true",
        dest="force_write_profiles",
        default=False,
        help="Unconditionally write/update profiling data.",
    )
    make_option(
        "--dump-pyannotate",
        type=str,
        dest="dump_pyannotate",
        help="Run pyannotate and dump json info to given file",
    )
    make_option(
        "--mypy-extra-test-path",
        type=str,
        action="append",
        default=[],
        dest="mypy_extra_test_paths",
        help="Additional test directories to add to the mypy tests. "
        "This is used only when running mypy tests. Multiple OK",
    )
+
+
def configure_follower(follower_ident):
    """Configure required state for a follower.

    This invokes in the parent process and typically includes
    database creation.

    """
    from sqlalchemy.testing import provision

    # the provision module holds the follower ident as process-global state
    provision.FOLLOWER_IDENT = follower_ident
+
+
def memoize_important_follower_config(dict_):
    """Store important configuration we will need to send to a follower.

    This invokes in the parent process after normal config is set up.

    This is necessary as pytest seems to not be using forking, so we
    start with nothing in memory, *but* it isn't running our argparse
    callables, so we have to just copy all of that over.

    """
    # snapshot of the tag sets built up by the option callbacks
    dict_["memoized_config"] = {
        "include_tags": include_tags,
        "exclude_tags": exclude_tags,
    }
+
+
def restore_important_follower_config(dict_):
    """Restore important configuration needed by a follower.

    This invokes in the follower process.

    """
    global include_tags, exclude_tags
    # merge the parent's snapshot into this process's tag sets
    memoized = dict_["memoized_config"]
    include_tags.update(memoized["include_tags"])
    exclude_tags.update(memoized["exclude_tags"])
+
+
def read_config():
    """Load test configuration into the module-global ``file_config``."""
    global file_config
    file_config = configparser.ConfigParser()
    # later files override earlier ones; missing files are silently skipped
    file_config.read(["setup.cfg", "test.cfg"])
+
+
def pre_begin(opt):
    """things to set up early, before coverage might be setup."""
    global options
    options = opt
    # run all hooks registered via @pre
    for fn in pre_configure:
        fn(options, file_config)
+
+
def set_coverage_flag(value):
    """Record on the options object whether coverage is active."""
    options.has_coverage = value
+
+
def post_begin():
    """things to set up later, once we know coverage is running."""
    # Lazy setup of other options (post coverage)
    for fn in post_configure:
        fn(options, file_config)

    # late imports, has to happen after config.
    # these populate the module-level placeholders declared above so the
    # library is not imported before coverage starts
    global util, fixtures, engines, exclusions, assertions, provision
    global warnings, profiling, config, testing
    from sqlalchemy import testing  # noqa
    from sqlalchemy.testing import fixtures, engines, exclusions  # noqa
    from sqlalchemy.testing import assertions, warnings, profiling  # noqa
    from sqlalchemy.testing import config, provision  # noqa
    from sqlalchemy import util  # noqa

    warnings.setup_filters()
+
+
def _log(opt_str, value, parser):
    """Option callback for --log-info / --log-debug."""
    global logging
    if not logging:
        # the module-level ``logging`` name is deliberately None until a
        # logging option is first used; import lazily and configure a
        # basic handler
        import logging

        logging.basicConfig()

    # the option name itself selects the level for the named logger
    if opt_str.endswith("-info"):
        logging.getLogger(value).setLevel(logging.INFO)
    elif opt_str.endswith("-debug"):
        logging.getLogger(value).setLevel(logging.DEBUG)
+
+
def _list_dbs(*args):
    """Option callback for --dbs: print the [db] config section and exit."""
    print("Available --db options (use --dburi to override)")
    for macro in sorted(file_config.options("db")):
        print("%20s\t%s" % (macro, file_config.get("db", macro)))
    # listing is terminal; no tests run
    sys.exit(0)
+
+
def _requirements_opt(opt_str, value, parser):
    """Option callback for --requirements: install the named class now."""
    _setup_requirements(value)


def _exclude_tag(opt_str, value, parser):
    # tags are stored with underscores; the CLI accepts dashes
    exclude_tags.add(value.replace("-", "_"))


def _include_tag(opt_str, value, parser):
    # tags are stored with underscores; the CLI accepts dashes
    include_tags.add(value.replace("-", "_"))
+
+
# ordered hook registries; pre hooks run before coverage setup,
# post hooks after
pre_configure = []
post_configure = []


def pre(fn):
    """Decorator registering *fn* as a pre-configure hook; returns fn."""
    pre_configure.append(fn)
    return fn


def post(fn):
    """Decorator registering *fn* as a post-configure hook; returns fn."""
    post_configure.append(fn)
    return fn
+
+
@pre
def _setup_options(opt, file_config):
    # stash parsed options in the module-global ``options``
    global options
    options = opt


@pre
def _set_nomemory(opt, file_config):
    # --nomemory: exclude memory-profiling tests via tag
    if opt.nomemory:
        exclude_tags.add("memory_intensive")


@pre
def _set_notimingintensive(opt, file_config):
    # --notimingintensive: exclude wall-clock-sensitive tests via tag
    if opt.notimingintensive:
        exclude_tags.add("timing_intensive")


@pre
def _monkeypatch_cdecimal(options, file_config):
    # --with-cdecimal: substitute the cdecimal C implementation for the
    # stdlib ``decimal`` module for the entire run
    if options.cdecimal:
        import cdecimal

        sys.modules["decimal"] = cdecimal


@post
def _init_symbols(options, file_config):
    from sqlalchemy.testing import config

    # install the framework-specific fixture implementation registered
    # earlier via set_fixture_functions()
    config._fixture_functions = _fixture_fn_class()


@post
def _set_disable_asyncio(opt, file_config):
    # asyncio support requires Python 3; also honor --disable-asyncio
    if opt.disable_asyncio or not py3k:
        from sqlalchemy.testing import asyncio

        asyncio.ENABLE_ASYNCIO = False
+
+
@post
def _engine_uri(options, file_config):
    """Resolve --db / --dburi options into database URLs and build a
    Config for each, making the first one current."""

    from sqlalchemy import testing
    from sqlalchemy.testing import config
    from sqlalchemy.testing import provision

    if options.dburi:
        db_urls = list(options.dburi)
    else:
        db_urls = []

    extra_drivers = options.dbdriver or []

    if options.db:
        for db_token in options.db:
            # each --db value may contain several comma/space separated
            # keys referring to the [db] section of the config file
            for db in re.split(r"[,\s]+", db_token):
                if db not in file_config.options("db"):
                    raise RuntimeError(
                        "Unknown URI specifier '%s'. "
                        "Specify --dbs for known uris." % db
                    )
                else:
                    db_urls.append(file_config.get("db", db))

    if not db_urls:
        db_urls.append(file_config.get("db", "default"))

    config._current = None

    # the provisioning system may expand each URL across the extra drivers
    expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))

    for db_url in expanded_urls:
        log.info("Adding database URL: %s", db_url)

        if options.write_idents and provision.FOLLOWER_IDENT:
            # record follower-ident -> URL mapping for xdist runs
            with open(options.write_idents, "a") as file_:
                file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")

        cfg = provision.setup_config(
            db_url, options, file_config, provision.FOLLOWER_IDENT
        )
        if not config._current:
            # the first configured URL becomes the default config
            cfg.set_as_current(cfg, testing)
+
+
@post
def _requirements(options, file_config):
    # install the requirements class named in setup.cfg, unless the
    # --requirements option already installed one
    requirement_cls = file_config.get("sqla_testing", "requirement_cls")
    _setup_requirements(requirement_cls)
+
+
def _setup_requirements(argument):
    """Resolve *argument* of the form ``module.path:ClassName`` and
    install an instance as the active requirements object.

    No-op if requirements were already configured (e.g. by the
    --requirements option callback running before the setup.cfg hook).
    """
    from sqlalchemy.testing import config
    from sqlalchemy import testing

    if config.requirements is not None:
        return

    modname, clsname = argument.split(":")

    # importlib.import_module is available on all supported interpreters
    # (2.7+); the previous hand-rolled __import__/getattr walk only
    # existed for pre-2.7 compatibility
    import importlib

    mod = importlib.import_module(modname)
    req_cls = getattr(mod, clsname)

    config.requirements = testing.requires = req_cls()

    config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
+
+
@post
def _prep_testing_database(options, file_config):
    from sqlalchemy.testing import config

    # --dropfirst: clear out all schema objects from every configured
    # database before the run begins
    if options.dropfirst:
        from sqlalchemy.testing import provision

        for cfg in config.Config.all_configs():
            provision.drop_all_schema_objects(cfg, cfg.db)
+
+
@post
def _reverse_topological(options, file_config):
    # --reversetop: randomize unit-of-work ordering to flush out hidden
    # dependency-ordering bugs
    if options.reversetop:
        from sqlalchemy.orm.util import randomize_unitofwork

        randomize_unitofwork()
+
+
@post
def _post_setup_options(opt, file_config):
    from sqlalchemy.testing import config

    # NOTE: uses the module-global ``options`` (set by _setup_options),
    # which is the same object as ``opt`` at this point
    config.options = options
    config.file_config = file_config
+
+
@post
def _setup_profiling(options, file_config):
    from sqlalchemy.testing import profiling

    # the profile file location comes from setup.cfg; sort/dump behavior
    # from the command line
    profiling._profile_stats = profiling.ProfileStatsFile(
        file_config.get("sqla_testing", "profile_file"),
        sort=options.profilesort,
        dump=options.profiledump,
    )
+
+
def want_class(name, cls):
    """Return True if the given test class should be collected."""
    # only TestBase subclasses whose name is not underscore-prefixed
    if not issubclass(cls, fixtures.TestBase) or name.startswith("_"):
        return False
    # --backend-only keeps just classes marked as backend tests
    if config.options.backend_only and not (
        getattr(cls, "__backend__", False)
        or getattr(cls, "__sparse_backend__", False)
        or getattr(cls, "__only_on__", False)
    ):
        return False
    return True
+
+
def want_method(cls, fn):
    """Return True if the given test method should be collected, applying
    the include/exclude tag sets."""
    if not fn.__name__.startswith("test_"):
        return False
    elif fn.__module__ is None:
        return False
    elif include_tags:
        # with --include-tag, a test runs only if its class tags or its
        # own exclusion-extension tags match
        return (
            hasattr(cls, "__tags__")
            and exclusions.tags(cls.__tags__).include_test(
                include_tags, exclude_tags
            )
        ) or (
            hasattr(fn, "_sa_exclusion_extend")
            and fn._sa_exclusion_extend.include_test(
                include_tags, exclude_tags
            )
        )
    elif exclude_tags and hasattr(cls, "__tags__"):
        return exclusions.tags(cls.__tags__).include_test(
            include_tags, exclude_tags
        )
    elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
        return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
    else:
        # untagged tests run by default
        return True
+
+
def generate_sub_tests(cls, module):
    """Yield per-backend subclasses of *cls* (or cls itself if it is not
    a backend test), installing each subclass on *module*."""
    if getattr(cls, "__backend__", False) or getattr(
        cls, "__sparse_backend__", False
    ):
        sparse = getattr(cls, "__sparse_backend__", False)
        for cfg in _possible_configs_for_cls(cls, sparse=sparse):
            orig_name = cls.__name__

            # we can have special chars in these names except for the
            # pytest junit plugin, which is tripped up by the brackets
            # and periods, so sanitize

            alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
            alpha_name = re.sub(r"_+$", "", alpha_name)
            name = "%s_%s" % (cls.__name__, alpha_name)
            # the subclass pins itself to a single config via
            # __only_on_config__
            subcls = type(
                name,
                (cls,),
                {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
            )
            setattr(module, name, subcls)
            yield subcls
    else:
        yield cls
+
+
def start_test_class_outside_fixtures(cls):
    """Per-class setup run outside any fixture/asyncio context."""
    _do_skips(cls)
    _setup_engine(cls)
+
+
def stop_test_class(cls):
    """Per-class teardown that runs inside the fixture context."""
    # close sessions, immediate connections, etc.
    fixtures.stop_test_class_inside_fixtures(cls)

    # close outstanding connection pool connections, dispose of
    # additional engines
    engines.testing_reaper.stop_test_class_inside_fixtures()
+
+
def stop_test_class_outside_fixtures(cls):
    """Per-class teardown run outside the fixture context; always
    restores the default engine config."""
    engines.testing_reaper.stop_test_class_outside_fixtures()
    provision.stop_test_class_outside_fixtures(config, config.db, cls)
    try:
        # connection-leak assertions are skipped under --low-connections
        if not options.low_connections:
            assertions.global_cleanup_assertions()
    finally:
        _restore_engine()
+
+
def _restore_engine():
    # undo any per-class engine pushed by _setup_engine()
    if config._current:
        config._current.reset(testing)
+
+
def final_process_cleanup():
    """End-of-process cleanup: reap engines, check for leaks, restore
    the engine config."""
    engines.testing_reaper.final_cleanup()
    assertions.global_cleanup_assertions()
    _restore_engine()
+
+
def _setup_engine(cls):
    # honor a per-class __engine_options__ dict by pushing a dedicated
    # class-scoped engine for the duration of the class
    if getattr(cls, "__engine_options__", None):
        opts = dict(cls.__engine_options__)
        opts["scope"] = "class"
        eng = engines.testing_engine(options=opts)
        config._current.push_engine(eng, testing)
+
+
def before_test(test, test_module_name, test_class, test_name):
    """Per-test setup; announces the test id to the profiling system."""

    # format looks like:
    # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"

    # use the pre-sub-test class name so profiling keys stay stable
    # across per-backend generated subclasses
    name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)

    id_ = "%s.%s.%s" % (test_module_name, name, test_name)

    profiling._start_current_test(id_)
+
+
def after_test(test):
    """Per-test teardown inside the fixture context."""
    fixtures.after_test()
    engines.testing_reaper.after_test()


def after_test_fixtures(test):
    """Per-test teardown after fixtures have been torn down."""
    engines.testing_reaper.after_test_outside_fixtures(test)
+
+
def _possible_configs_for_cls(cls, reasons=None, sparse=False):
    """Return the set of database configs on which *cls* can run,
    applying the class's exclusion/requirement markers.

    Skip reasons are appended to *reasons* when provided; with
    ``sparse=True`` only one config per base dialect is returned.
    """
    all_configs = set(config.Config.all_configs())

    # remove configs matching __unsupported_on__ specs
    if cls.__unsupported_on__:
        spec = exclusions.db_spec(*cls.__unsupported_on__)
        for config_obj in list(all_configs):
            if spec(config_obj):
                all_configs.remove(config_obj)

    # keep only configs matching __only_on__ specs
    if getattr(cls, "__only_on__", None):
        spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
        for config_obj in list(all_configs):
            if not spec(config_obj):
                all_configs.remove(config_obj)

    # per-backend generated subclasses pin a single config
    if getattr(cls, "__only_on_config__", None):
        all_configs.intersection_update([cls.__only_on_config__])

    # drop configs failing any entry in __requires__
    if hasattr(cls, "__requires__"):
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__requires__:
                check = getattr(requirements, requirement)

                skip_reasons = check.matching_config_reasons(config_obj)
                if skip_reasons:
                    all_configs.remove(config_obj)
                    if reasons is not None:
                        reasons.extend(skip_reasons)
                    break

    # __prefer_requires__ only narrows when a preferred config remains
    if hasattr(cls, "__prefer_requires__"):
        non_preferred = set()
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__prefer_requires__:
                check = getattr(requirements, requirement)

                if not check.enabled_for_config(config_obj):
                    non_preferred.add(config_obj)
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    if sparse:
        # pick only one config from each base dialect
        # sorted so we get the same backend each time selecting the highest
        # server version info.
        per_dialect = {}
        for cfg in reversed(
            sorted(
                all_configs,
                key=lambda cfg: (
                    cfg.db.name,
                    cfg.db.driver,
                    cfg.db.dialect.server_version_info,
                ),
            )
        ):
            db = cfg.db.name
            if db not in per_dialect:
                per_dialect[db] = cfg
        return per_dialect.values()

    return all_configs
+
+
def _do_skips(cls):
    """Skip *cls* entirely when no configured database supports it,
    otherwise switch the current config to a supported one."""
    reasons = []
    all_configs = _possible_configs_for_cls(cls, reasons)

    # honor explicit __skip_if__ callables
    if getattr(cls, "__skip_if__", False):
        for c in getattr(cls, "__skip_if__"):
            if c():
                config.skip_test(
                    "'%s' skipped by %s" % (cls.__name__, c.__name__)
                )

    if not all_configs:
        # build a message listing every configured backend plus the
        # collected requirement reasons
        msg = "'%s' unsupported on any DB implementation %s%s" % (
            cls.__name__,
            ", ".join(
                "'%s(%s)+%s'"
                % (
                    config_obj.db.name,
                    ".".join(
                        str(dig)
                        for dig in exclusions._server_version(config_obj.db)
                    ),
                    config_obj.db.driver,
                )
                for config_obj in config.Config.all_configs()
            ),
            ", ".join(reasons),
        )
        config.skip_test(msg)
    elif hasattr(cls, "__prefer_backends__"):
        # narrow to preferred backends when at least one is available
        non_preferred = set()
        spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
        for config_obj in all_configs:
            if not spec(config_obj):
                non_preferred.add(config_obj)
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    # switch to a supported config if the current one was excluded
    if config._current not in all_configs:
        _setup_config(all_configs.pop(), cls)
+
+
def _setup_config(config_obj, ctx):
    # push the given config as current; popped again by _restore_engine()
    config._current.push(config_obj, testing)
+
+
class FixtureFunctions(ABC):
    """Interface the test-framework plugin implements to supply skip,
    combination and fixture primitives to ``sqlalchemy.testing.config``."""

    @abc.abstractmethod
    def skip_test_exception(self, *arg, **kw):
        raise NotImplementedError()

    @abc.abstractmethod
    def combinations(self, *args, **kw):
        raise NotImplementedError()

    @abc.abstractmethod
    def param_ident(self, *args, **kw):
        raise NotImplementedError()

    @abc.abstractmethod
    def fixture(self, *arg, **kw):
        raise NotImplementedError()

    def get_current_test_name(self):
        # NOTE(review): not abstract, unlike the sibling methods —
        # presumably optional for implementers; confirm before changing
        raise NotImplementedError()

    @abc.abstractmethod
    def mark_base_test_class(self):
        raise NotImplementedError()
+
+
# the FixtureFunctions implementation class, installed by the active
# test-framework plugin before _init_symbols() instantiates it
_fixture_fn_class = None


def set_fixture_functions(fixture_fn_class):
    """Register the FixtureFunctions implementation class to use."""
    global _fixture_fn_class
    _fixture_fn_class = fixture_fn_class
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
new file mode 100644
index 0000000..5a51582
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -0,0 +1,820 @@
+try:
+ # installed by bootstrap.py
+ import sqla_plugin_base as plugin_base
+except ImportError:
+ # assume we're a package, use traditional import
+ from . import plugin_base
+
+import argparse
+import collections
+from functools import update_wrapper
+import inspect
+import itertools
+import operator
+import os
+import re
+import sys
+import uuid
+
+import pytest
+
+
+py2k = sys.version_info < (3, 0)
+if py2k:
+ try:
+ import sqla_reinvent_fixtures as reinvent_fixtures_py2k
+ except ImportError:
+ from . import reinvent_fixtures_py2k
+
+
def pytest_addoption(parser):
    """pytest hook: declare sqlalchemy options, adapting plugin_base's
    optparse-style declarations onto argparse."""
    group = parser.getgroup("sqlalchemy")

    def make_option(name, **kw):
        # adapt optparse-style "callback" options to argparse Actions
        callback_ = kw.pop("callback", None)
        if callback_:

            class CallableAction(argparse.Action):
                def __call__(
                    self, parser, namespace, values, option_string=None
                ):
                    callback_(option_string, values, parser)

            kw["action"] = CallableAction

        # zero-argument callbacks (e.g. --dbs) become store_const-like
        # Actions that invoke the callback
        zeroarg_callback = kw.pop("zeroarg_callback", None)
        if zeroarg_callback:

            class CallableAction(argparse.Action):
                def __init__(
                    self,
                    option_strings,
                    dest,
                    default=False,
                    required=False,
                    help=None,  # noqa
                ):
                    super(CallableAction, self).__init__(
                        option_strings=option_strings,
                        dest=dest,
                        nargs=0,
                        const=True,
                        default=default,
                        required=required,
                        help=help,
                    )

                def __call__(
                    self, parser, namespace, values, option_string=None
                ):
                    zeroarg_callback(option_string, values, parser)

            kw["action"] = CallableAction

        group.addoption(name, **kw)

    plugin_base.setup_options(make_option)
    plugin_base.read_config()
+
+
def pytest_configure(config):
    """pytest hook: wire up xdist, follower state and plugin_base."""
    if config.pluginmanager.hasplugin("xdist"):
        config.pluginmanager.register(XDistHooks())

    if hasattr(config, "workerinput"):
        # we are an xdist follower: restore state memoized by the master
        plugin_base.restore_important_follower_config(config.workerinput)
        plugin_base.configure_follower(config.workerinput["follower_ident"])
    else:
        # master process: start the idents file fresh for this run
        if config.option.write_idents and os.path.exists(
            config.option.write_idents
        ):
            os.remove(config.option.write_idents)

    plugin_base.pre_begin(config.option)

    plugin_base.set_coverage_flag(
        bool(getattr(config.option, "cov_source", False))
    )

    plugin_base.set_fixture_functions(PytestFixtureFunctions)

    if config.option.dump_pyannotate:
        global DUMP_PYANNOTATE
        DUMP_PYANNOTATE = True
+
+
# module-level flag; set True by pytest_configure when --dump-pyannotate
# is given
DUMP_PYANNOTATE = False
+
+
@pytest.fixture(autouse=True)
def collect_types_fixture():
    """Autouse fixture that records pyannotate type info around every
    test when --dump-pyannotate is enabled."""
    if DUMP_PYANNOTATE:
        from pyannotate_runtime import collect_types

        collect_types.start()
    yield
    if DUMP_PYANNOTATE:
        collect_types.stop()
+
+
def pytest_sessionstart(session):
    from sqlalchemy.testing import asyncio

    # post_begin performs the library's late imports; run it within the
    # asyncio compatibility wrapper
    asyncio._assume_async(plugin_base.post_begin)
+
+
def pytest_sessionfinish(session):
    from sqlalchemy.testing import asyncio

    asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)

    # write out pyannotate data collected over the whole session
    if session.config.option.dump_pyannotate:
        from pyannotate_runtime import collect_types

        collect_types.dump_stats(session.config.option.dump_pyannotate)
+
+
def pytest_collection_finish(session):
    """Configure pyannotate's file filter once collection is complete."""
    if session.config.option.dump_pyannotate:
        from pyannotate_runtime import collect_types

        lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")

        def _filter(filename):
            # restrict collection to the library itself, excluding the
            # testing package
            # NOTE(review): the "lib/sqlalchemy" substring check assumes
            # forward-slash paths — confirm behavior on Windows
            filename = os.path.normpath(os.path.abspath(filename))
            if "lib/sqlalchemy" not in os.path.commonpath(
                [filename, lib_sqlalchemy]
            ):
                return None
            if "testing" in filename:
                return None

            return filename

        collect_types.init_types_collection(filter_filename=_filter)
+
+
class XDistHooks(object):
    """pytest-xdist hooks creating/dropping a per-follower database."""

    def pytest_configure_node(self, node):
        from sqlalchemy.testing import provision
        from sqlalchemy.testing import asyncio

        # the master for each node fills workerinput dictionary
        # which pytest-xdist will transfer to the subprocess

        plugin_base.memoize_important_follower_config(node.workerinput)

        node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]

        asyncio._maybe_async_provisioning(
            provision.create_follower_db, node.workerinput["follower_ident"]
        )

    def pytest_testnodedown(self, node, error):
        from sqlalchemy.testing import provision
        from sqlalchemy.testing import asyncio

        # drop the follower's database when its node shuts down
        asyncio._maybe_async_provisioning(
            provision.drop_follower_db, node.workerinput["follower_ident"]
        )
+
+
def pytest_collection_modifyitems(session, config, items):
    """Expand __backend__ test classes into per-database sub-classes and
    re-sort the collected items."""

    # look for all those classes that specify __backend__ and
    # expand them out into per-database test cases.

    # this is much easier to do within pytest_pycollect_makeitem, however
    # pytest is iterating through cls.__dict__ as makeitem is
    # called which causes a "dictionary changed size" error on py3k.
    # I'd submit a pullreq for them to turn it into a list first, but
    # it's to suit the rather odd use case here which is that we are adding
    # new classes to a module on the fly.

    from sqlalchemy.testing import asyncio

    # cls -> test name -> list of generated items
    rebuilt_items = collections.defaultdict(
        lambda: collections.defaultdict(list)
    )

    # drop items from underscore-prefixed (abstract base) classes
    items[:] = [
        item
        for item in items
        if item.getparent(pytest.Class) is not None
        and not item.getparent(pytest.Class).name.startswith("_")
    ]

    test_classes = set(item.getparent(pytest.Class) for item in items)

    def collect(element):
        # recursively flatten a collector into its test functions
        for inst_or_fn in element.collect():
            if isinstance(inst_or_fn, pytest.Collector):
                # no yield from in 2.7
                for el in collect(inst_or_fn):
                    yield el
            else:
                yield inst_or_fn

    def setup_test_classes():
        for test_class in test_classes:
            for sub_cls in plugin_base.generate_sub_tests(
                test_class.cls, test_class.module
            ):
                if sub_cls is not test_class.cls:
                    per_cls_dict = rebuilt_items[test_class.cls]

                    # support pytest 5.4.0 and above pytest.Class.from_parent
                    ctor = getattr(pytest.Class, "from_parent", pytest.Class)
                    module = test_class.getparent(pytest.Module)
                    for fn in collect(
                        ctor(name=sub_cls.__name__, parent=module)
                    ):
                        per_cls_dict[fn.name].append(fn)

    # class requirements will sometimes need to access the DB to check
    # capabilities, so need to do this for async
    asyncio._maybe_async_provisioning(setup_test_classes)

    # swap original items for their generated per-backend equivalents
    newitems = []
    for item in items:
        cls_ = item.cls
        if cls_ in rebuilt_items:
            newitems.extend(rebuilt_items[cls_][item.name])
        else:
            newitems.append(item)

    if py2k:
        for item in newitems:
            reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item)

    # seems like the functions attached to a test class aren't sorted already?
    # is that true and why's that? (when using unittest, they're sorted)
    items[:] = sorted(
        newitems,
        key=lambda item: (
            item.getparent(pytest.Module).name,
            item.getparent(pytest.Class).name,
            item.name,
        ),
    )
+
+
def pytest_pycollect_makeitem(collector, name, obj):
    """pytest hook: decide whether a class or function is collected,
    delegating the policy to plugin_base."""
    if inspect.isclass(obj) and plugin_base.want_class(name, obj):
        from sqlalchemy.testing import config

        # wrap test methods so they can run under asyncio when any
        # configured backend is async
        if config.any_async:
            obj = _apply_maybe_async(obj)

        ctor = getattr(pytest.Class, "from_parent", pytest.Class)
        return [
            ctor(name=parametrize_cls.__name__, parent=collector)
            for parametrize_cls in _parametrize_cls(collector.module, obj)
        ]
    elif (
        inspect.isfunction(obj)
        and collector.cls is not None
        and plugin_base.want_method(collector.cls, obj)
    ):
        # None means, fall back to default logic, which includes
        # method-level parametrize
        return None
    else:
        # empty list means skip this item
        return []
+
+
+def _is_wrapped_coroutine_function(fn):
+ while hasattr(fn, "__wrapped__"):
+ fn = fn.__wrapped__
+
+ return inspect.iscoroutinefunction(fn)
+
+
def _apply_maybe_async(obj, recurse=True):
    """Wrap the ``test_*`` methods of class ``obj`` (and, when
    ``recurse`` is True, of its base classes) so they run through
    ``asyncio._maybe_async``.

    Methods already wrapped (flagged via ``_maybe_async_applied``) and
    genuine coroutine functions are left alone.  Returns ``obj``.
    """
    from sqlalchemy.testing import asyncio

    for name, value in vars(obj).items():
        if (
            (callable(value) or isinstance(value, classmethod))
            and not getattr(value, "_maybe_async_applied", False)
            and (name.startswith("test_"))
            and not _is_wrapped_coroutine_function(value)
        ):
            is_classmethod = False
            if isinstance(value, classmethod):
                # unwrap to the plain function for decoration; the
                # classmethod wrapper is re-applied below
                value = value.__func__
                is_classmethod = True

            @_pytest_fn_decorator
            def make_async(fn, *args, **kwargs):
                return asyncio._maybe_async(fn, *args, **kwargs)

            do_async = make_async(value)
            if is_classmethod:
                do_async = classmethod(do_async)
            # flag prevents double-wrapping on repeated application
            do_async._maybe_async_applied = True

            setattr(obj, name, do_async)
    if recurse:
        # process superclasses in place, non-recursively, since this
        # mro walk already covers the full hierarchy
        for cls in obj.mro()[1:]:
            if cls != object:
                _apply_maybe_async(cls, False)
    return obj
+
+
+def _parametrize_cls(module, cls):
+ """implement a class-based version of pytest parametrize."""
+
+ if "_sa_parametrize" not in cls.__dict__:
+ return [cls]
+
+ _sa_parametrize = cls._sa_parametrize
+ classes = []
+ for full_param_set in itertools.product(
+ *[params for argname, params in _sa_parametrize]
+ ):
+ cls_variables = {}
+
+ for argname, param in zip(
+ [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
+ ):
+ if not argname:
+ raise TypeError("need argnames for class-based combinations")
+ argname_split = re.split(r",\s*", argname)
+ for arg, val in zip(argname_split, param.values):
+ cls_variables[arg] = val
+ parametrized_name = "_".join(
+ # token is a string, but in py2k pytest is giving us a unicode,
+ # so call str() on it.
+ str(re.sub(r"\W", "", token))
+ for param in full_param_set
+ for token in param.id.split("-")
+ )
+ name = "%s_%s" % (cls.__name__, parametrized_name)
+ newcls = type.__new__(type, name, (cls,), cls_variables)
+ setattr(module, name, newcls)
+ classes.append(newcls)
+ return classes
+
+
# the pytest.Class currently undergoing setup/teardown, so that the
# class-level start/stop hooks run exactly once per class
_current_class = None


def pytest_runtest_setup(item):
    """Per-test setup hook; runs class-level start logic on the first
    test of each class."""
    from sqlalchemy.testing import asyncio

    # pytest_runtest_setup runs *before* pytest fixtures with scope="class".
    # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
    # for the whole class and has to run things that are across all current
    # databases, so we run this outside of the pytest fixture system altogether
    # and ensure asyncio greenlet if any engines are async

    global _current_class

    if isinstance(item, pytest.Function) and _current_class is None:
        asyncio._maybe_async_provisioning(
            plugin_base.start_test_class_outside_fixtures,
            item.cls,
        )
        _current_class = item.getparent(pytest.Class)
+
+
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item, nextitem):
    """Teardown hook; finalizes the test class once its last test ends.

    Runs plugin_base.after_test before yielding; after all fixture
    teardowns complete, runs class-level stop logic when ``nextitem``
    belongs to a different class (or is None, i.e. the last test).
    """
    # runs inside of pytest function fixture scope
    # after test function runs
    from sqlalchemy.testing import asyncio
    from sqlalchemy.util import string_types

    asyncio._maybe_async(plugin_base.after_test, item)

    yield
    # this is now after all the fixture teardown have run, the class can be
    # finalized. Since pytest v7 this finalizer can no longer be added in
    # pytest_runtest_setup since the class has not yet been setup at that
    # time.
    # See https://github.com/pytest-dev/pytest/issues/9343
    global _current_class, _current_report

    if _current_class is not None and (
        # last test or a new class
        nextitem is None
        or nextitem.getparent(pytest.Class) is not _current_class
    ):
        _current_class = None

        try:
            asyncio._maybe_async_provisioning(
                plugin_base.stop_test_class_outside_fixtures, item.cls
            )
        except Exception as e:
            # in case of an exception during teardown attach the original
            # error to the exception message, otherwise it will get lost
            if _current_report.failed:
                if not e.args:
                    e.args = (
                        "__Original test failure__:\n"
                        + _current_report.longreprtext,
                    )
                elif e.args[-1] and isinstance(e.args[-1], string_types):
                    # last arg is a non-empty string; append to it
                    args = list(e.args)
                    args[-1] += (
                        "\n__Original test failure__:\n"
                        + _current_report.longreprtext
                    )
                    e.args = tuple(args)
                else:
                    e.args += (
                        "__Original test failure__",
                        _current_report.longreprtext,
                    )
            raise
        finally:
            _current_report = None
+
+
def pytest_runtest_call(item):
    """Run plugin_base.before_test immediately before the test call.

    Executes inside of pytest function fixture scope, before the test
    function runs; async-aware via asyncio._maybe_async.
    """

    from sqlalchemy.testing import asyncio

    asyncio._maybe_async(
        plugin_base.before_test,
        item,
        item.module.__name__,
        item.cls,
        item.name,
    )
+
+
# most recent "call"-phase test report; consulted by
# pytest_runtest_teardown so teardown exceptions can carry the original
# test failure text
_current_report = None


def pytest_runtest_logreport(report):
    """Capture the "call"-phase report of each test for later use."""
    global _current_report
    if report.when == "call":
        _current_report = report
+
+
@pytest.fixture(scope="class")
def setup_class_methods(request):
    """Class-scope fixture running setup_test_class / teardown_test_class
    (plus the homegrown py2k class fixtures) around a class's tests."""
    from sqlalchemy.testing import asyncio

    cls = request.cls

    if hasattr(cls, "setup_test_class"):
        asyncio._maybe_async(cls.setup_test_class)

    if py2k:
        reinvent_fixtures_py2k.run_class_fixture_setup(request)

    yield

    if py2k:
        reinvent_fixtures_py2k.run_class_fixture_teardown(request)

    if hasattr(cls, "teardown_test_class"):
        asyncio._maybe_async(cls.teardown_test_class)

    asyncio._maybe_async(plugin_base.stop_test_class, cls)
+
+
@pytest.fixture(scope="function")
def setup_test_methods(request):
    """Function-scope fixture running per-test setup/teardown methods
    around each test; the numbered comments below document the full
    ordering relative to pytest's own fixture machinery."""
    from sqlalchemy.testing import asyncio

    # called for each test

    self = request.instance

    # before this fixture runs:

    # 1. function level "autouse" fixtures under py3k (examples: TablesTest
    # define tables / data, MappedTest define tables / mappers / data)

    # 2. run homegrown function level "autouse" fixtures under py2k
    if py2k:
        reinvent_fixtures_py2k.run_fn_fixture_setup(request)

    # 3. run outer xdist-style setup
    if hasattr(self, "setup_test"):
        asyncio._maybe_async(self.setup_test)

    # alembic test suite is using setUp and tearDown
    # xdist methods; support these in the test suite
    # for the near term
    if hasattr(self, "setUp"):
        asyncio._maybe_async(self.setUp)

    # inside the yield:
    # 4. function level fixtures defined on test functions themselves,
    # e.g. "connection", "metadata" run next

    # 5. pytest hook pytest_runtest_call then runs

    # 6. test itself runs

    yield

    # yield finishes:

    # 7. function level fixtures defined on test functions
    # themselves, e.g. "connection" rolls back the transaction, "metadata"
    # emits drop all

    # 8. pytest hook pytest_runtest_teardown hook runs, this is associated
    # with fixtures close all sessions, provisioning.stop_test_class(),
    # engines.testing_reaper -> ensure all connection pool connections
    # are returned, engines created by testing_engine that aren't the
    # config engine are disposed

    # NOTE(review): there is no step 9 in this sequence; the gap in the
    # numbering appears historical
    asyncio._maybe_async(plugin_base.after_test_fixtures, self)

    # 10. run xdist-style teardown
    if hasattr(self, "tearDown"):
        asyncio._maybe_async(self.tearDown)

    if hasattr(self, "teardown_test"):
        asyncio._maybe_async(self.teardown_test)

    # 11. run homegrown function-level "autouse" fixtures under py2k
    if py2k:
        reinvent_fixtures_py2k.run_fn_fixture_teardown(request)

    # 12. function level "autouse" fixtures under py3k (examples: TablesTest /
    # MappedTest delete table data, possibly drop tables and clear mappers
    # depending on the flags defined by the test class)
+
+
def getargspec(fn):
    """Return the (full) argument spec for ``fn``.

    Uses ``inspect.getfullargspec`` on Python 3 and the legacy
    ``inspect.getargspec`` on Python 2.
    """
    if sys.version_info.major < 3:
        return inspect.getargspec(fn)
    return inspect.getfullargspec(fn)
+
+
def _pytest_fn_decorator(target):
    """Port of langhelpers.decorator with pytest-specific tricks.

    Generates, via exec, a wrapper that preserves the decorated
    function's exact signature (optionally extended with extra
    positional parameters) and forwards to ``target(fn, ...)``.
    """

    from sqlalchemy.util.langhelpers import format_argspec_plus
    from sqlalchemy.util.compat import inspect_getfullargspec

    def _exec_code_in_env(code, env, fn_name):
        # compile the generated wrapper source and return the function
        exec(code, env)
        return env[fn_name]

    def decorate(fn, add_positional_parameters=()):

        spec = inspect_getfullargspec(fn)
        if add_positional_parameters:
            spec.args.extend(add_positional_parameters)

        metadata = dict(
            __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
        )
        metadata.update(format_argspec_plus(spec, grouped=False))
        code = (
            """\
def %(name)s(%(args)s):
    return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
"""
            % metadata
        )
        decorated = _exec_code_in_env(
            code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
        )
        if not add_positional_parameters:
            decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
            decorated.__wrapped__ = fn
            return update_wrapper(decorated, fn)
        else:
            # this is the pytest hacky part. don't do a full update wrapper
            # because pytest is really being sneaky about finding the args
            # for the wrapped function
            decorated.__module__ = fn.__module__
            decorated.__name__ = fn.__name__
            if hasattr(fn, "pytestmark"):
                decorated.pytestmark = fn.pytestmark
            return decorated

    return decorate
+
+
class PytestFixtureFunctions(plugin_base.FixtureFunctions):
    """pytest-specific implementation of the FixtureFunctions contract
    used by plugin_base, bridging SQLAlchemy's testing decorators
    (combinations, fixture, async_test, etc.) onto pytest APIs."""

    def skip_test_exception(self, *arg, **kw):
        # the exception type pytest raises internally for skips
        return pytest.skip.Exception(*arg, **kw)

    def mark_base_test_class(self):
        # ensures the homegrown class/function setup fixtures defined in
        # this module are applied to every base test class
        return pytest.mark.usefixtures(
            "setup_class_methods", "setup_test_methods"
        )

    # id_ template characters mapped to value-to-string converters; used
    # by combinations() to build parameter-set ids
    _combination_id_fns = {
        "i": lambda obj: obj,
        "r": repr,
        "s": str,
        "n": lambda obj: obj.__name__
        if hasattr(obj, "__name__")
        else type(obj).__name__,
    }

    def combinations(self, *arg_sets, **kw):
        """Facade for pytest.mark.parametrize.

        Automatically derives argument names from the callable which in our
        case is always a method on a class with positional arguments.

        ids for parameter sets are derived using an optional template.

        """
        from sqlalchemy.testing import exclusions

        # a single generator argument is expanded into the full arg sets
        if sys.version_info.major == 3:
            if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
                arg_sets = list(arg_sets[0])
        else:
            if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
                arg_sets = list(arg_sets[0])

        argnames = kw.pop("argnames", None)

        def _filter_exclusions(args):
            # separate exclusion rules from actual parameter values
            result = []
            gathered_exclusions = []
            for a in args:
                if isinstance(a, exclusions.compound):
                    gathered_exclusions.append(a)
                else:
                    result.append(a)

            return result, gathered_exclusions

        id_ = kw.pop("id_", None)

        tobuild_pytest_params = []
        has_exclusions = False
        if id_:
            _combination_id_fns = self._combination_id_fns

            # because itemgetter is not consistent for one argument vs.
            # multiple, make it multiple in all cases and use a slice
            # to omit the first argument
            _arg_getter = operator.itemgetter(
                0,
                *[
                    idx
                    for idx, char in enumerate(id_)
                    if char in ("n", "r", "s", "a")
                ]
            )
            fns = [
                (operator.itemgetter(idx), _combination_id_fns[char])
                for idx, char in enumerate(id_)
                if char in _combination_id_fns
            ]

            for arg in arg_sets:
                if not isinstance(arg, tuple):
                    arg = (arg,)

                fn_params, param_exclusions = _filter_exclusions(arg)

                parameters = _arg_getter(fn_params)[1:]

                if param_exclusions:
                    has_exclusions = True

                tobuild_pytest_params.append(
                    (
                        parameters,
                        param_exclusions,
                        "-".join(
                            comb_fn(getter(arg)) for getter, comb_fn in fns
                        ),
                    )
                )

        else:

            for arg in arg_sets:
                if not isinstance(arg, tuple):
                    arg = (arg,)

                fn_params, param_exclusions = _filter_exclusions(arg)

                if param_exclusions:
                    has_exclusions = True

                tobuild_pytest_params.append(
                    (fn_params, param_exclusions, None)
                )

        pytest_params = []
        for parameters, param_exclusions, id_ in tobuild_pytest_params:
            if has_exclusions:
                # exclusions ride along as an extra trailing parameter
                # consumed by the check_exclusions wrapper below
                parameters += (param_exclusions,)

            param = pytest.param(*parameters, id=id_)
            pytest_params.append(param)

        def decorate(fn):
            if inspect.isclass(fn):
                if has_exclusions:
                    raise NotImplementedError(
                        "exclusions not supported for class level combinations"
                    )
                # class-level parametrize: record the sets for
                # _parametrize_cls to expand at collection time
                if "_sa_parametrize" not in fn.__dict__:
                    fn._sa_parametrize = []
                fn._sa_parametrize.append((argnames, pytest_params))
                return fn
            else:
                if argnames is None:
                    # skip "self"
                    _argnames = getargspec(fn).args[1:]
                else:
                    _argnames = re.split(r", *", argnames)

                if has_exclusions:
                    _argnames += ["_exclusions"]

                    @_pytest_fn_decorator
                    def check_exclusions(fn, *args, **kw):
                        _exclusions = args[-1]
                        if _exclusions:
                            exlu = exclusions.compound().add(*_exclusions)
                            fn = exlu(fn)
                        return fn(*args[0:-1], **kw)

                    # NOTE(review): process_metadata is defined but never
                    # referenced here; appears vestigial
                    def process_metadata(spec):
                        spec.args.append("_exclusions")

                    fn = check_exclusions(
                        fn, add_positional_parameters=("_exclusions",)
                    )

                return pytest.mark.parametrize(_argnames, pytest_params)(fn)

        return decorate

    def param_ident(self, *parameters):
        # first element is the id; the remainder are the param values
        ident = parameters[0]
        return pytest.param(*parameters[1:], id=ident)

    def fixture(self, *arg, **kw):
        from sqlalchemy.testing import config
        from sqlalchemy.testing import asyncio

        # wrapping pytest.fixture function. determine if
        # decorator was called as @fixture or @fixture().
        if len(arg) > 0 and callable(arg[0]):
            # was called as @fixture directly, we already have the
            # function to wrap.
            fn = arg[0]
            arg = arg[1:]
        else:
            # was called as @fixture(...), don't have the function yet.
            fn = None

        # create a pytest.fixture marker. because the fn is not being
        # passed, this is always a pytest.FixtureFunctionMarker()
        # object (or whatever pytest is calling it when you read this)
        # that is waiting for a function.
        fixture = pytest.fixture(*arg, **kw)

        # now apply wrappers to the function, including fixture itself

        def wrap(fn):
            if config.any_async:
                fn = asyncio._maybe_async_wrapper(fn)
            # other wrappers may be added here

            if py2k and "autouse" in kw:
                # py2k workaround for too-slow collection of autouse fixtures
                # in pytest 4.6.11. See notes in reinvent_fixtures_py2k for
                # rationale.

                # comment this condition out in order to disable the
                # py2k workaround entirely.
                reinvent_fixtures_py2k.add_fixture(fn, fixture)
            else:
                # now apply FixtureFunctionMarker
                fn = fixture(fn)

            return fn

        if fn:
            return wrap(fn)
        else:
            return wrap

    def get_current_test_name(self):
        # pytest exposes the id of the running test via this env variable
        return os.environ.get("PYTEST_CURRENT_TEST")

    def async_test(self, fn):
        from sqlalchemy.testing import asyncio

        @_pytest_fn_decorator
        def decorate(fn, *args, **kwargs):
            asyncio._run_coroutine_function(fn, *args, **kwargs)

        return decorate(fn)
diff --git a/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
new file mode 100644
index 0000000..36b6841
--- /dev/null
+++ b/lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
@@ -0,0 +1,112 @@
+"""
invent a quick version of pytest autouse fixtures, since pytest's collection
is unacceptably slow and its memory use very high in pytest 4.6.11, which is
the highest version that works in py2k.
+
+by "too-slow" we mean the test suite can't even manage to be collected for a
+single process in less than 70 seconds or so and memory use seems to be very
+high as well. for two or four workers the job just times out after ten
+minutes.
+
+so instead we have invented a very limited form of these fixtures, as our
+current use of "autouse" fixtures are limited to those in fixtures.py.
+
+assumptions for these fixtures:
+
+1. we are only using "function" or "class" scope
+
+2. the functions must be associated with a test class
+
+3. the fixture functions cannot themselves use pytest fixtures
+
+4. the fixture functions must use yield, not return
+
+When py2k support is removed and we can stay on a modern pytest version, this
+can all be removed.
+
+
+"""
+import collections
+
+
# fixture-function name -> set of (fn, scope) tuples registered via
# add_fixture()
_py2k_fixture_fn_names = collections.defaultdict(set)
# test class -> defining superclass -> set of bound class-scope fixtures
_py2k_class_fixtures = collections.defaultdict(
    lambda: collections.defaultdict(set)
)
# test class -> defining superclass -> set of bound function-scope fixtures
_py2k_function_fixtures = collections.defaultdict(
    lambda: collections.defaultdict(set)
)

# in-progress fixture generators, popped and resumed at teardown time
_py2k_cls_fixture_stack = []
_py2k_fn_fixture_stack = []
+
+
def add_fixture(fn, fixture):
    """Register ``fn`` as a homegrown py2k fixture, with the scope taken
    from the pytest ``fixture`` marker it would have used."""
    assert fixture.scope in ("class", "function")
    _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope))
+
+
def scan_for_fixtures_to_use_for_class(item):
    """Match registered fixture functions against ``item``'s test class
    and record them in the class/function fixture registries, keyed by
    the superclass that defines each fixture."""
    # item.parent.parent is the pytest.Class node; .obj is the class itself
    test_class = item.parent.parent.obj

    for name in _py2k_fixture_fn_names:
        for fixture_fn, scope in _py2k_fixture_fn_names[name]:
            meth = getattr(test_class, name, None)
            # im_func: py2k unbound-method attribute holding the plain
            # function, compared by identity against the registered fn
            if meth and meth.im_func is fixture_fn:
                for sup in test_class.__mro__:
                    if name in sup.__dict__:
                        if scope == "class":
                            _py2k_class_fixtures[test_class][sup].add(meth)
                        elif scope == "function":
                            _py2k_function_fixtures[test_class][sup].add(meth)
                        break
                break
+
+
def run_class_fixture_setup(request):
    """Start all class-scope fixture generators for the request's
    class, pushing each onto the class fixture stack for teardown."""

    cls = request.cls
    # fixtures are written as instance methods; run them against a
    # transient instance created without invoking __init__
    self = cls.__new__(cls)

    fixtures_for_this_class = _py2k_class_fixtures.get(cls)

    if fixtures_for_this_class:
        for sup_ in cls.__mro__:
            for fn in fixtures_for_this_class.get(sup_, ()):
                # advance the generator to its yield (the setup phase)
                iter_ = fn(self)
                next(iter_)

                _py2k_cls_fixture_stack.append(iter_)
+
+
def run_class_fixture_teardown(request):
    """Resume (and thereby finalize) all pending class-scope fixture
    generators, in reverse order of setup."""
    while _py2k_cls_fixture_stack:
        gen = _py2k_cls_fixture_stack.pop()
        try:
            next(gen)
        except StopIteration:
            pass
+
+
def run_fn_fixture_setup(request):
    """Start all function-scope fixture generators for the request's
    class, pushing each onto the function fixture stack."""
    cls = request.cls
    self = request.instance

    fixtures_for_this_class = _py2k_function_fixtures.get(cls)

    if fixtures_for_this_class:
        # note: base classes first here, unlike the class-scope version
        for sup_ in reversed(cls.__mro__):
            for fn in fixtures_for_this_class.get(sup_, ()):
                # advance the generator to its yield (the setup phase)
                iter_ = fn(self)
                next(iter_)

                _py2k_fn_fixture_stack.append(iter_)
+
+
def run_fn_fixture_teardown(request):
    """Resume (and thereby finalize) all pending function-scope fixture
    generators, most recently started first."""
    while _py2k_fn_fixture_stack:
        gen = _py2k_fn_fixture_stack.pop()
        try:
            next(gen)
        except StopIteration:
            pass
diff --git a/lib/sqlalchemy/testing/profiling.py b/lib/sqlalchemy/testing/profiling.py
new file mode 100644
index 0000000..4132630
--- /dev/null
+++ b/lib/sqlalchemy/testing/profiling.py
@@ -0,0 +1,335 @@
+# testing/profiling.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Profiling support for unit and performance tests.
+
+These are special purpose profiling methods which operate
+in a more fine-grained way than nose's profiling plugin.
+
+"""
+
+import collections
+import contextlib
+import os
+import platform
+import pstats
+import re
+import sys
+
+from . import config
+from .util import gc_collect
+from ..util import has_compiled_ext
+
+
# cProfile is optional on some Python builds; when it is unavailable,
# count_functions() skips the test instead of failing
try:
    import cProfile
except ImportError:
    cProfile = None
+
_profile_stats = None
"""global ProfileStatsFile instance.

plugin_base assigns this at the start of all tests.

"""


_current_test = None
"""String id of current test.

plugin_base assigns this at the start of each test using
_start_current_test.

"""
+
+
def _start_current_test(id_):
    """Record ``id_`` as the currently running test.

    Called by plugin_base at the start of each test; also resets
    previously recorded counts when --force-write-profiles is in
    effect so they are regenerated.
    """
    global _current_test
    _current_test = id_

    if _profile_stats.force_write:
        _profile_stats.reset_count()
+
+
class ProfileStatsFile(object):
    """Store per-platform/fn profiling results in a file.

    There was no json module available when this was written, but now
    the file format which is very deterministically line oriented is kind of
    handy in any case for diffs and merges.

    """

    def __init__(self, filename, sort="cumulative", dump=None):
        """Load (and, when writing is enabled, immediately rewrite) the
        stats file.

        :param filename: path of the callcount file.
        :param sort: pstats sort key(s) used when printing profiles.
        :param dump: optional base filename to dump raw profile data to.
        """
        self.force_write = (
            config.options is not None and config.options.force_write_profiles
        )
        self.write = self.force_write or (
            config.options is not None and config.options.write_profiles
        )
        self.fname = os.path.abspath(filename)
        self.short_fname = os.path.split(self.fname)[-1]
        # test key -> platform key -> {"counts", "lineno", "current_count"}
        self.data = collections.defaultdict(
            lambda: collections.defaultdict(dict)
        )
        self.dump = dump
        self.sort = sort
        self._read()
        if self.write:
            # rewrite for the case where features changed,
            # etc.
            self._write()

    @property
    def platform_key(self):
        """Token identifying the current platform / interpreter /
        backend combination under which counts are recorded."""

        dbapi_key = config.db.name + "_" + config.db.driver

        if config.db.name == "sqlite" and config.db.dialect._is_url_file_db(
            config.db.url
        ):
            dbapi_key += "_file"

        # keep it at 2.7, 3.1, 3.2, etc. for now.
        py_version = ".".join([str(v) for v in sys.version_info[0:2]])

        platform_tokens = [
            platform.machine(),
            platform.system().lower(),
            platform.python_implementation().lower(),
            py_version,
            dbapi_key,
        ]

        platform_tokens.append(
            "nativeunicode"
            if config.db.dialect.convert_unicode
            else "dbapiunicode"
        )
        _has_cext = has_compiled_ext()
        platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
        return "_".join(platform_tokens)

    def has_stats(self):
        """Return True if counts exist for the current test + platform."""
        test_key = _current_test
        return (
            test_key in self.data and self.platform_key in self.data[test_key]
        )

    def result(self, callcount):
        """Record or look up ``callcount`` for the current test.

        Returns ``(lineno, expected_count)`` when an expected count is
        already on file for this invocation; otherwise records the new
        count (writing the file if enabled) and returns None.
        """
        test_key = _current_test
        per_fn = self.data[test_key]
        per_platform = per_fn[self.platform_key]

        if "counts" not in per_platform:
            per_platform["counts"] = counts = []
        else:
            counts = per_platform["counts"]

        if "current_count" not in per_platform:
            per_platform["current_count"] = current_count = 0
        else:
            current_count = per_platform["current_count"]

        has_count = len(counts) > current_count

        if not has_count:
            counts.append(callcount)
            if self.write:
                self._write()
            result = None
        else:
            result = per_platform["lineno"], counts[current_count]
        per_platform["current_count"] += 1
        return result

    def reset_count(self):
        """Discard recorded counts for the current test on this platform."""
        test_key = _current_test
        # since self.data is a defaultdict, don't access a key
        # if we don't know it's there first.
        if test_key not in self.data:
            return
        per_fn = self.data[test_key]
        if self.platform_key not in per_fn:
            return
        per_platform = per_fn[self.platform_key]
        if "counts" in per_platform:
            per_platform["counts"][:] = []

    def replace(self, callcount):
        """Overwrite the most recently consulted count with ``callcount``,
        writing the file if enabled."""
        test_key = _current_test
        per_fn = self.data[test_key]
        per_platform = per_fn[self.platform_key]
        counts = per_platform["counts"]
        current_count = per_platform["current_count"]
        if current_count < len(counts):
            counts[current_count - 1] = callcount
        else:
            counts[-1] = callcount
        if self.write:
            self._write()

    def _header(self):
        """Return the explanatory comment block written atop the file."""
        return (
            "# %s\n"
            "# This file is written out on a per-environment basis.\n"
            "# For each test in aaa_profiling, the corresponding "
            "function and \n"
            "# environment is located within this file. "
            "If it doesn't exist,\n"
            "# the test is skipped.\n"
            "# If a callcount does exist, it is compared "
            "to what we received. \n"
            "# assertions are raised if the counts do not match.\n"
            "# \n"
            "# To add a new callcount test, apply the function_call_count \n"
            "# decorator and re-run the tests using the --write-profiles \n"
            "# option - this file will be rewritten including the new count.\n"
            "# \n"
        ) % (self.fname)

    def _read(self):
        """Populate self.data from the stats file, if it exists."""
        try:
            profile_f = open(self.fname)
        except IOError:
            # no stats file yet; start empty
            return
        # context manager ensures the handle is closed even if a line
        # fails to parse (the original code leaked it on error)
        with profile_f:
            for lineno, line in enumerate(profile_f):
                line = line.strip()
                if not line or line.startswith("#"):
                    continue

                test_key, platform_key, counts = line.split()
                per_fn = self.data[test_key]
                per_platform = per_fn[platform_key]
                c = [int(count) for count in counts.split(",")]
                per_platform["counts"] = c
                per_platform["lineno"] = lineno + 1
                per_platform["current_count"] = 0

    def _write(self):
        """Write all recorded counts out to the stats file."""
        print(("Writing profile file %s" % self.fname))
        # context manager so the file is closed even on a write error
        with open(self.fname, "w") as profile_f:
            profile_f.write(self._header())
            for test_key in sorted(self.data):

                per_fn = self.data[test_key]
                profile_f.write("\n# TEST: %s\n\n" % test_key)
                for platform_key in sorted(per_fn):
                    per_platform = per_fn[platform_key]
                    c = ",".join(
                        str(count) for count in per_platform["counts"]
                    )
                    profile_f.write(
                        "%s %s %s\n" % (test_key, platform_key, c)
                    )
+
+
def function_call_count(variance=0.05, times=1, warmup=0):
    """Assert a target for a test case's function call count.

    The main purpose of this assertion is to detect changes in
    callcounts for various functions - the actual number is not as important.
    Callcounts are stored in a file keyed to Python version and OS platform
    information. This file is generated automatically for new tests,
    and versioned so that unexpected changes in callcounts will be detected.

    :param variance: allowed fractional deviation from the recorded count.
    :param times: number of profiled invocations of the function.
    :param warmup: number of unprofiled warm-up invocations beforehand.
    """

    # use signature-rewriting decorator function so that pytest fixtures
    # still work on py27. In Py3, update_wrapper() alone is good enough,
    # likely due to the introduction of __signature__.

    from sqlalchemy.util import decorator
    from sqlalchemy.util import deprecations
    from sqlalchemy.engine import row
    from sqlalchemy.testing import mock

    @decorator
    def wrap(fn, *args, **kw):

        # suppress 2.0 deprecation warnings and legacy Row key warnings
        # so the warning machinery doesn't perturb the callcounts
        with mock.patch.object(
            deprecations, "SQLALCHEMY_WARN_20", False
        ), mock.patch.object(
            row.LegacyRow, "_default_key_style", row.KEY_OBJECTS_NO_WARN
        ):
            for warm in range(warmup):
                fn(*args, **kw)

            timerange = range(times)
            with count_functions(variance=variance):
                for time in timerange:
                    rv = fn(*args, **kw)
                return rv

    return wrap
+
+
@contextlib.contextmanager
def count_functions(variance=0.05):
    """Profile the enclosed block and compare its total function call
    count against the recorded expectation for the current test.

    :param variance: allowed fractional deviation from the expected count.
    :raises AssertionError: when the observed count is outside the
        allowed variance and profile (re)writing is not enabled.
    """
    if cProfile is None:
        raise config._skip_test_exception("cProfile is not installed")

    if not _profile_stats.has_stats() and not _profile_stats.write:
        config.skip_test(
            "No profiling stats available on this "
            "platform for this function. Run tests with "
            "--write-profiles to add statistics to %s for "
            "this platform." % _profile_stats.short_fname
        )

    # collect garbage up front so prior allocations don't skew the counts
    gc_collect()

    pr = cProfile.Profile()
    pr.enable()
    # began = time.time()
    yield
    # ended = time.time()
    pr.disable()

    # s = compat.StringIO()
    stats = pstats.Stats(pr, stream=sys.stdout)

    # timespent = ended - began
    callcount = stats.total_calls

    # record or retrieve the expected count for this invocation
    expected = _profile_stats.result(callcount)

    if expected is None:
        expected_count = None
    else:
        line_no, expected_count = expected

    print(("Pstats calls: %d Expected %s" % (callcount, expected_count)))
    stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort))
    stats.print_stats()
    if _profile_stats.dump:
        base, ext = os.path.splitext(_profile_stats.dump)
        test_name = _current_test.split(".")[-1]
        dumpfile = "%s_%s%s" % (base, test_name, ext or ".profile")
        stats.dump_stats(dumpfile)
        print("Dumped stats to file %s" % dumpfile)
    # stats.print_callers()
    if _profile_stats.force_write:
        _profile_stats.replace(callcount)
    elif expected_count:
        deviance = int(callcount * variance)
        failed = abs(callcount - expected_count) > deviance

        if failed:
            if _profile_stats.write:
                _profile_stats.replace(callcount)
            else:
                raise AssertionError(
                    "Adjusted function call count %s not within %s%% "
                    "of expected %s, platform %s. Rerun with "
                    "--write-profiles to "
                    "regenerate this callcount."
                    % (
                        callcount,
                        (variance * 100),
                        expected_count,
                        _profile_stats.platform_key,
                    )
                )
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
new file mode 100644
index 0000000..90c4d93
--- /dev/null
+++ b/lib/sqlalchemy/testing/provision.py
@@ -0,0 +1,416 @@
+import collections
+import logging
+
+from . import config
+from . import engines
+from . import util
+from .. import exc
+from .. import inspect
+from ..engine import url as sa_url
+from ..sql import ddl
+from ..sql import schema
+from ..util import compat
+
+
+log = logging.getLogger(__name__)
+
+FOLLOWER_IDENT = None
+
+
class register(object):
    """A per-backend registry of provisioning functions.

    A ``register`` instance acts as a dispatching callable: functions are
    associated with backend names via :meth:`.for_db`, and invoking the
    instance routes to the function registered for the backend of the
    given config/URL, falling back to the ``"*"`` default.
    """

    def __init__(self):
        self.fns = {}

    @classmethod
    def init(cls, fn):
        """Create a new registry with ``fn`` installed as the default."""
        return register().for_db("*")(fn)

    def for_db(self, *dbnames):
        """Decorator registering a function for the given backend names."""

        def decorate(fn):
            for dbname in dbnames:
                self.fns[dbname] = fn
            return self

        return decorate

    def __call__(self, cfg, *arg):
        # accept a URL string, a URL object, or a config object
        if isinstance(cfg, compat.string_types):
            url = sa_url.make_url(cfg)
        elif isinstance(cfg, sa_url.URL):
            url = cfg
        else:
            url = cfg.db.url

        backend = url.get_backend_name()
        try:
            fn = self.fns[backend]
        except KeyError:
            fn = self.fns["*"]
        return fn(cfg, *arg)
+
+
def create_follower_db(follower_ident):
    """Create a database named ``follower_ident`` on each distinct host."""
    for cfg_obj in _configs_for_db_operation():
        log.info("CREATE database %s, URI %r", follower_ident, cfg_obj.db.url)
        create_db(cfg_obj, cfg_obj.db, follower_ident)
+
+
def setup_config(db_url, options, file_config, follower_ident):
    """Build and register a testing config for the given database URL.

    Loads the dialect (which installs its provisioning hooks), rewrites
    the URL for a follower database when ``follower_ident`` is present,
    creates a testing engine, verifies connectivity, and registers the
    resulting config.

    :param db_url: database URL string.
    :param options: command-line options object.
    :param file_config: parsed test configuration file object.
    :param follower_ident: pytest-xdist worker identifier, or None.
    :return: the registered config object.
    """
    # load the dialect, which should also have it set up its provision
    # hooks
    dialect = sa_url.make_url(db_url).get_dialect()
    dialect.load_provisioning()

    if follower_ident:
        db_url = follower_url_from_main(db_url, follower_ident)
    db_opts = {}
    update_db_opts(db_url, db_opts)
    db_opts["scope"] = "global"
    eng = engines.testing_engine(db_url, db_opts)
    post_configure_engine(db_url, eng, follower_ident)
    # verify the engine can actually connect before registering it
    eng.connect().close()

    cfg = config.Config.register(eng, db_opts, options, file_config)

    if follower_ident:
        # a symbolic name that tests can use if they need to disambiguate
        # names across databases
        config.ident = follower_ident
        configure_follower(cfg, follower_ident)
    return cfg
+
+
def drop_follower_db(follower_ident):
    """Drop the database named ``follower_ident`` on each distinct host."""
    for cfg_obj in _configs_for_db_operation():
        log.info("DROP database %s, URI %r", follower_ident, cfg_obj.db.url)
        drop_db(cfg_obj, cfg_obj.db, follower_ident)
+
+
def generate_db_urls(db_urls, extra_drivers):
    """Generate a set of URLs to test given configured URLs plus additional
    driver names.

    Given::

        --dburi postgresql://db1 \
        --dburi postgresql://db2 \
        --dburi postgresql://db3 \
        --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true

    Noting that the default postgresql driver is psycopg2, the output
    would be::

        postgresql+psycopg2://db1
        postgresql+asyncpg://db1
        postgresql+psycopg2://db2
        postgresql+psycopg2://db3

    That is, for the driver in a --dburi, we want to keep that and use that
    driver for each URL it's part of. For a driver that is only
    in --dbdrivers, we want to use it just once for one of the URLs.
    for a driver that is both coming from --dburi as well as --dbdrivers,
    we want to keep it in that dburi.

    Driver specific query options can be specified by adding them to the
    driver name. For example, to enable the async fallback option for
    asyncpg::

        --dburi postgresql://db1 \
        --dbdriver=asyncpg?async_fallback=true

    """
    urls = set()

    # per backend, the drivers already present among the configured URLs
    backend_to_driver_we_already_have = collections.defaultdict(set)

    urls_plus_dialects = [
        (url_obj, url_obj.get_dialect())
        for url_obj in [sa_url.make_url(db_url) for db_url in db_urls]
    ]

    for url_obj, dialect in urls_plus_dialects:
        backend_to_driver_we_already_have[dialect.name].add(dialect.driver)

    backend_to_driver_we_need = {}

    for url_obj, dialect in urls_plus_dialects:
        backend = dialect.name
        dialect.load_provisioning()

        # extra drivers still to be assigned for this backend; the set is
        # shared across URLs of the same backend and consumed in place by
        # _generate_driver_urls
        if backend not in backend_to_driver_we_need:
            backend_to_driver_we_need[backend] = extra_per_backend = set(
                extra_drivers
            ).difference(backend_to_driver_we_already_have[backend])
        else:
            extra_per_backend = backend_to_driver_we_need[backend]

        for driver_url in _generate_driver_urls(url_obj, extra_per_backend):
            if driver_url in urls:
                continue
            urls.add(driver_url)
            yield driver_url
+
+
def _generate_driver_urls(url, extra_drivers):
    """Yield URL strings for ``url`` plus one variant per extra driver.

    ``extra_drivers`` is mutated in place: the URL's own driver is
    discarded, and each driver successfully applied to a URL is removed
    so it is not reused for another URL of the same backend.
    """
    primary_driver = url.get_driver_name()
    extra_drivers.discard(primary_driver)

    url = generate_driver_url(url, primary_driver, "")
    yield str(url)

    # iterate a copy, since entries are removed as they are consumed
    for candidate in list(extra_drivers):
        if "?" in candidate:
            driver_only, query_str = candidate.split("?", 1)
        else:
            driver_only, query_str = candidate, None

        new_url = generate_driver_url(url, driver_only, query_str)
        if new_url:
            extra_drivers.remove(candidate)
            yield str(new_url)
+
+
@register.init
def generate_driver_url(url, driver, query_str):
    """Return ``url`` rewritten to use ``driver``, or None if unavailable.

    ``query_str``, when non-empty, is merged into the URL's query string.
    A driver whose dialect module cannot be imported results in None.
    """
    new_url = url.set(
        drivername="%s+%s" % (url.get_backend_name(), driver),
    )
    if query_str:
        new_url = new_url.update_query_string(query_str)

    try:
        new_url.get_dialect()
    except exc.NoSuchModuleError:
        return None
    return new_url
+
+
def _configs_for_db_operation():
    """Yield one config per distinct (backend, user, host, database).

    Connection pools of all configs are disposed before and after the
    generation so that CREATE/DROP operations see no held connections.
    """
    seen_hosts = set()

    for cfg in config.Config.all_configs():
        cfg.db.dispose()

    for cfg in config.Config.all_configs():
        url = cfg.db.url
        host_key = (url.get_backend_name(), url.username, url.host, url.database)
        if host_key in seen_hosts:
            continue
        seen_hosts.add(host_key)
        yield cfg

    for cfg in config.Config.all_configs():
        cfg.db.dispose()
+
+
@register.init
def drop_all_schema_objects_pre_tables(cfg, eng):
    """Dialect hook run by drop_all_schema_objects before any drops.

    Default is a no-op; dialects override via ``register.for_db()``.
    """
    pass
+
+
@register.init
def drop_all_schema_objects_post_tables(cfg, eng):
    """Dialect hook run by drop_all_schema_objects after tables are dropped.

    Default is a no-op; dialects override via ``register.for_db()``.
    """
    pass
+
+
def drop_all_schema_objects(cfg, eng):
    """Drop all views, tables, and sequences for a testing database.

    Runs the dialect-level pre/post hooks around the drops; schema- and
    sequence-related steps run only when the config's requirements
    enable them.
    """

    drop_all_schema_objects_pre_tables(cfg, eng)

    inspector = inspect(eng)
    # views must go first, since they may depend on tables
    try:
        view_names = inspector.get_view_names()
    except NotImplementedError:
        # dialect does not support view reflection; nothing to drop
        pass
    else:
        with eng.begin() as conn:
            for vname in view_names:
                conn.execute(
                    ddl._DropView(schema.Table(vname, schema.MetaData()))
                )

    if config.requirements.schemas.enabled_for_config(cfg):
        try:
            # NOTE(review): literal "test_schema" here, while the table
            # drops below use cfg.test_schema — confirm these always agree
            view_names = inspector.get_view_names(schema="test_schema")
        except NotImplementedError:
            pass
        else:
            with eng.begin() as conn:
                for vname in view_names:
                    conn.execute(
                        ddl._DropView(
                            schema.Table(
                                vname,
                                schema.MetaData(),
                                schema="test_schema",
                            )
                        )
                    )

    util.drop_all_tables(eng, inspector)
    if config.requirements.schemas.enabled_for_config(cfg):
        util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
        util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)

    drop_all_schema_objects_post_tables(cfg, eng)

    if config.requirements.sequences.enabled_for_config(cfg):
        with eng.begin() as conn:
            for seq in inspector.get_sequence_names():
                conn.execute(ddl.DropSequence(schema.Sequence(seq)))
            if config.requirements.schemas.enabled_for_config(cfg):
                for schema_name in [cfg.test_schema, cfg.test_schema_2]:
                    for seq in inspector.get_sequence_names(
                        schema=schema_name
                    ):
                        conn.execute(
                            ddl.DropSequence(
                                schema.Sequence(seq, schema=schema_name)
                            )
                        )
+
+
@register.init
def create_db(cfg, eng, ident):
    """Dynamically create a database for testing.

    Used when a test run will employ multiple processes, e.g., when run
    via `tox` or `pytest -n4`.

    The default implementation raises; dialect-specific implementations
    are registered via ``register.for_db()``.
    """
    raise NotImplementedError(
        "no DB creation routine for cfg: %s" % (eng.url,)
    )
+
+
@register.init
def drop_db(cfg, eng, ident):
    """Drop a database that we dynamically created for testing.

    The default implementation raises; dialect-specific implementations
    are registered via ``register.for_db()``.
    """
    raise NotImplementedError("no DB drop routine for cfg: %s" % (eng.url,))
+
+
@register.init
def update_db_opts(db_url, db_opts):
    """Set database options (db_opts) for a test database that we created.

    Default is a no-op; dialects override via ``register.for_db()``.
    """
    pass
+
+
@register.init
def post_configure_engine(url, engine, follower_ident):
    """Perform extra steps after configuring an engine for testing.

    (For the internal dialects, currently only used by sqlite, oracle)

    Default is a no-op; dialects override via ``register.for_db()``.
    """
    pass
+
+
@register.init
def follower_url_from_main(url, ident):
    """Create a connection URL for a dynamically-created test database.

    :param url: the connection URL specified when the test run was invoked
    :param ident: the pytest-xdist "worker identifier" to be used as the
        database name
    """
    return sa_url.make_url(url).set(database=ident)
+
+
@register.init
def configure_follower(cfg, ident):
    """Create dialect-specific config settings for a follower database.

    Default is a no-op; dialects override via ``register.for_db()``.
    """
    pass
+
+
@register.init
def run_reap_dbs(url, ident):
    """Remove databases that were created during the test process, after the
    process has ended.

    This is an optional step that is invoked for certain backends that do not
    reliably release locks on the database as long as a process is still in
    use. For the internal dialects, this is currently only necessary for
    mssql and oracle.

    Default is a no-op; dialects override via ``register.for_db()``.
    """
    pass
+
+
def reap_dbs(idents_file):
    """Drop leftover per-process test databases listed in ``idents_file``.

    Each line of the file is "<db_name> <db_url>"; entries are grouped by
    (backend, host) so that each backend's ``run_reap_dbs`` hook is invoked
    once with the full set of database names to remove.

    :param idents_file: path to the file of recorded database idents.
    """
    log.info("Reaping databases...")

    urls = collections.defaultdict(set)
    idents = collections.defaultdict(set)
    dialects = {}

    with open(idents_file) as file_:
        for line in file_:
            # split only on the first space so a URL that itself contains
            # spaces (e.g. in query arguments) is kept intact
            db_name, db_url = line.strip().split(" ", 1)
            url_obj = sa_url.make_url(db_url)
            if db_name not in dialects:
                # load each backend's provisioning hooks exactly once
                dialects[db_name] = url_obj.get_dialect()
                dialects[db_name].load_provisioning()
            url_key = (url_obj.get_backend_name(), url_obj.host)
            urls[url_key].add(db_url)
            idents[url_key].add(db_name)

    for url_key in urls:
        # any URL for this backend/host will do as a connection target
        url = next(iter(urls[url_key]))
        ident = idents[url_key]
        run_reap_dbs(url, ident)
+
+
@register.init
def temp_table_keyword_args(cfg, eng):
    """Specify keyword arguments for creating a temporary Table.

    Dialect-specific implementations of this method will return the
    kwargs that are passed to the Table method when creating a temporary
    table for testing, e.g., in the define_temp_tables method of the
    ComponentReflectionTest class in suite/test_reflection.py

    The default implementation raises; dialects must override via
    ``register.for_db()``.
    """
    raise NotImplementedError(
        "no temp table keyword args routine for cfg: %s" % (eng.url,)
    )
+
+
@register.init
def prepare_for_drop_tables(config, connection):
    """Dialect hook to prepare a connection before tables are dropped.

    Default is a no-op; dialects override via ``register.for_db()``.
    """
    pass
+
+
@register.init
def stop_test_class_outside_fixtures(config, db, testcls):
    """Dialect hook run when a test class stops, outside fixture teardown.

    Default is a no-op; dialects override via ``register.for_db()``.
    """
    pass
+
+
@register.init
def get_temp_table_name(cfg, eng, base_name):
    """Specify table name for creating a temporary Table.

    Dialect-specific implementations of this method will return the
    name to use when creating a temporary table for testing,
    e.g., in the define_temp_tables method of the
    ComponentReflectionTest class in suite/test_reflection.py

    Default to just the base name since that's what most dialects will
    use. The mssql dialect's implementation will need a "#" prepended.

    :param base_name: the plain table name to be adapted.
    """
    return base_name
+
+
@register.init
def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
    """Set the default schema active on ``dbapi_connection``.

    The default implementation raises; dialects that support switching
    the default schema override via ``register.for_db()``.
    """
    raise NotImplementedError(
        "backend does not implement a schema name set function: %s"
        % (cfg.db.url,)
    )
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
new file mode 100644
index 0000000..857d1fd
--- /dev/null
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -0,0 +1,1518 @@
+# testing/requirements.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+"""Global database feature support policy.
+
+Provides decorators to mark tests requiring specific feature support from the
+target database.
+
+External dialect test suites should subclass SuiteRequirements
+to provide specific inclusion/exclusions.
+
+"""
+
+import platform
+import sys
+
+from . import exclusions
+from . import only_on
+from .. import util
+from ..pool import QueuePool
+
+
class Requirements(object):
    """Base marker class for requirements objects."""

    pass
+
+
+class SuiteRequirements(Requirements):
+ @property
+ def create_table(self):
+ """target platform can emit basic CreateTable DDL."""
+
+ return exclusions.open()
+
+ @property
+ def drop_table(self):
+ """target platform can emit basic DropTable DDL."""
+
+ return exclusions.open()
+
+ @property
+ def table_ddl_if_exists(self):
+ """target platform supports IF NOT EXISTS / IF EXISTS for tables."""
+
+ return exclusions.closed()
+
+ @property
+ def index_ddl_if_exists(self):
+ """target platform supports IF NOT EXISTS / IF EXISTS for indexes."""
+
+ return exclusions.closed()
+
+ @property
+ def foreign_keys(self):
+ """Target database must support foreign keys."""
+
+ return exclusions.open()
+
+ @property
+ def table_value_constructor(self):
+ """Database / dialect supports a query like::
+
+ SELECT * FROM VALUES ( (c1, c2), (c1, c2), ...)
+ AS some_table(col1, col2)
+
+ SQLAlchemy generates this with the :func:`_sql.values` function.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def standard_cursor_sql(self):
+ """Target database passes SQL-92 style statements to cursor.execute()
+ when a statement like select() or insert() is run.
+
+ A very small portion of dialect-level tests will ensure that certain
+ conditions are present in SQL strings, and these tests use very basic
+ SQL that will work on any SQL-like platform in order to assert results.
+
+ It's normally a given for any pep-249 DBAPI that a statement like
+ "SELECT id, name FROM table WHERE some_table.id=5" will work.
+ However, there are dialects that don't actually produce SQL Strings
+ and instead may work with symbolic objects instead, or dialects that
+ aren't working with SQL, so for those this requirement can be marked
+ as excluded.
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def on_update_cascade(self):
+ """target database must support ON UPDATE..CASCADE behavior in
+ foreign keys."""
+
+ return exclusions.open()
+
+ @property
+ def non_updating_cascade(self):
+ """target database must *not* support ON UPDATE..CASCADE behavior in
+ foreign keys."""
+ return exclusions.closed()
+
+ @property
+ def deferrable_fks(self):
+ return exclusions.closed()
+
+ @property
+ def on_update_or_deferrable_fks(self):
+ # TODO: exclusions should be composable,
+ # somehow only_if([x, y]) isn't working here, negation/conjunctions
+ # getting confused.
+ return exclusions.only_if(
+ lambda: self.on_update_cascade.enabled
+ or self.deferrable_fks.enabled
+ )
+
+ @property
+ def queue_pool(self):
+ """target database is using QueuePool"""
+
+ def go(config):
+ return isinstance(config.db.pool, QueuePool)
+
+ return exclusions.only_if(go)
+
+ @property
+ def self_referential_foreign_keys(self):
+ """Target database must support self-referential foreign keys."""
+
+ return exclusions.open()
+
+ @property
+ def foreign_key_ddl(self):
+ """Target database must support the DDL phrases for FOREIGN KEY."""
+
+ return exclusions.open()
+
+ @property
+ def named_constraints(self):
+ """target database must support names for constraints."""
+
+ return exclusions.open()
+
+ @property
+ def implicitly_named_constraints(self):
+ """target database must apply names to unnamed constraints."""
+
+ return exclusions.open()
+
+ @property
+ def subqueries(self):
+ """Target database must support subqueries."""
+
+ return exclusions.open()
+
+ @property
+ def offset(self):
+ """target database can render OFFSET, or an equivalent, in a
+ SELECT.
+ """
+
+ return exclusions.open()
+
+ @property
+ def bound_limit_offset(self):
+ """target database can render LIMIT and/or OFFSET using a bound
+ parameter
+ """
+
+ return exclusions.open()
+
    @property
    def sql_expression_limit_offset(self):
        """target database can render LIMIT and/or OFFSET with a complete
        SQL expression, such as one that uses the addition operator.
        """

        return exclusions.open()
+
+ @property
+ def parens_in_union_contained_select_w_limit_offset(self):
+ """Target database must support parenthesized SELECT in UNION
+ when LIMIT/OFFSET is specifically present.
+
+ E.g. (SELECT ...) UNION (SELECT ..)
+
+ This is known to fail on SQLite.
+
+ """
+ return exclusions.open()
+
+ @property
+ def parens_in_union_contained_select_wo_limit_offset(self):
+ """Target database must support parenthesized SELECT in UNION
+ when OFFSET/LIMIT is specifically not present.
+
+ E.g. (SELECT ... LIMIT ..) UNION (SELECT .. OFFSET ..)
+
+ This is known to fail on SQLite. It also fails on Oracle
+ because without LIMIT/OFFSET, there is currently no step that
+ creates an additional subquery.
+
+ """
+ return exclusions.open()
+
+ @property
+ def boolean_col_expressions(self):
+ """Target database must support boolean expressions as columns"""
+
+ return exclusions.closed()
+
+ @property
+ def nullable_booleans(self):
+ """Target database allows boolean columns to store NULL."""
+
+ return exclusions.open()
+
+ @property
+ def nullsordering(self):
+ """Target backends that support nulls ordering."""
+
+ return exclusions.closed()
+
+ @property
+ def standalone_binds(self):
+ """target database/driver supports bound parameters as column
+ expressions without being in the context of a typed column.
+ """
+ return exclusions.closed()
+
+ @property
+ def standalone_null_binds_whereclause(self):
+ """target database/driver supports bound parameters with NULL in the
+ WHERE clause, in situations where it has to be typed.
+
+ """
+ return exclusions.open()
+
+ @property
+ def intersect(self):
+ """Target database must support INTERSECT or equivalent."""
+ return exclusions.closed()
+
+ @property
+ def except_(self):
+ """Target database must support EXCEPT or equivalent (i.e. MINUS)."""
+ return exclusions.closed()
+
+ @property
+ def window_functions(self):
+ """Target database must support window functions."""
+ return exclusions.closed()
+
+ @property
+ def ctes(self):
+ """Target database supports CTEs"""
+
+ return exclusions.closed()
+
+ @property
+ def ctes_with_update_delete(self):
+ """target database supports CTES that ride on top of a normal UPDATE
+ or DELETE statement which refers to the CTE in a correlated subquery.
+
+ """
+
+ return exclusions.closed()
+
+ @property
+ def ctes_on_dml(self):
+ """target database supports CTES which consist of INSERT, UPDATE
+ or DELETE *within* the CTE, e.g. WITH x AS (UPDATE....)"""
+
+ return exclusions.closed()
+
+ @property
+ def autoincrement_insert(self):
+ """target platform generates new surrogate integer primary key values
+ when insert() is executed, excluding the pk column."""
+
+ return exclusions.open()
+
+ @property
+ def fetch_rows_post_commit(self):
+ """target platform will allow cursor.fetchone() to proceed after a
+ COMMIT.
+
+ Typically this refers to an INSERT statement with RETURNING which
+ is invoked within "autocommit". If the row can be returned
+ after the autocommit, then this rule can be open.
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def group_by_complex_expression(self):
+ """target platform supports SQL expressions in GROUP BY
+
+ e.g.
+
+ SELECT x + y AS somelabel FROM table GROUP BY x + y
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def sane_rowcount(self):
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.supports_sane_rowcount,
+ "driver doesn't support 'sane' rowcount",
+ )
+
+ @property
+ def sane_multi_rowcount(self):
+ return exclusions.fails_if(
+ lambda config: not config.db.dialect.supports_sane_multi_rowcount,
+ "driver %(driver)s %(doesnt_support)s 'sane' multi row count",
+ )
+
+ @property
+ def sane_rowcount_w_returning(self):
+ return exclusions.fails_if(
+ lambda config: not (
+ config.db.dialect.supports_sane_rowcount_returning
+ ),
+ "driver doesn't support 'sane' rowcount when returning is on",
+ )
+
+ @property
+ def empty_inserts(self):
+ """target platform supports INSERT with no values, i.e.
+ INSERT DEFAULT VALUES or equivalent."""
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.supports_empty_insert
+ or config.db.dialect.supports_default_values
+ or config.db.dialect.supports_default_metavalue,
+ "empty inserts not supported",
+ )
+
+ @property
+ def empty_inserts_executemany(self):
+ """target platform supports INSERT with no values, i.e.
+ INSERT DEFAULT VALUES or equivalent, within executemany()"""
+
+ return self.empty_inserts
+
+ @property
+ def insert_from_select(self):
+ """target platform supports INSERT from a SELECT."""
+
+ return exclusions.open()
+
+ @property
+ def full_returning(self):
+ """target platform supports RETURNING completely, including
+ multiple rows returned.
+
+ """
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.full_returning,
+ "%(database)s %(does_support)s 'RETURNING of multiple rows'",
+ )
+
+ @property
+ def insert_executemany_returning(self):
+ """target platform supports RETURNING when INSERT is used with
+ executemany(), e.g. multiple parameter sets, indicating
+ as many rows come back as do parameter sets were passed.
+
+ """
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.insert_executemany_returning,
+ "%(database)s %(does_support)s 'RETURNING of "
+ "multiple rows with INSERT executemany'",
+ )
+
+ @property
+ def returning(self):
+ """target platform supports RETURNING for at least one row.
+
+ .. seealso::
+
+ :attr:`.Requirements.full_returning`
+
+ """
+
+ return exclusions.only_if(
+ lambda config: config.db.dialect.implicit_returning,
+ "%(database)s %(does_support)s 'RETURNING of a single row'",
+ )
+
+ @property
+ def tuple_in(self):
+ """Target platform supports the syntax
+ "(x, y) IN ((x1, y1), (x2, y2), ...)"
+ """
+
+ return exclusions.closed()
+
+ @property
+ def tuple_in_w_empty(self):
+ """Target platform tuple IN w/ empty set"""
+ return self.tuple_in
+
+ @property
+ def duplicate_names_in_cursor_description(self):
+ """target platform supports a SELECT statement that has
+ the same name repeated more than once in the columns list."""
+
+ return exclusions.open()
+
+ @property
+ def denormalized_names(self):
+ """Target database must have 'denormalized', i.e.
+ UPPERCASE as case insensitive names."""
+
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.requires_name_normalize,
+ "Backend does not require denormalized names.",
+ )
+
+ @property
+ def multivalues_inserts(self):
+ """target database must support multiple VALUES clauses in an
+ INSERT statement."""
+
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.supports_multivalues_insert,
+ "Backend does not support multirow inserts.",
+ )
+
+ @property
+ def implements_get_lastrowid(self):
+ """target dialect implements the executioncontext.get_lastrowid()
+ method without reliance on RETURNING.
+
+ """
+ return exclusions.open()
+
+ @property
+ def emulated_lastrowid(self):
+ """target dialect retrieves cursor.lastrowid, or fetches
+ from a database-side function after an insert() construct executes,
+ within the get_lastrowid() method.
+
+ Only dialects that "pre-execute", or need RETURNING to get last
+ inserted id, would return closed/fail/skip for this.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def emulated_lastrowid_even_with_sequences(self):
+ """target dialect retrieves cursor.lastrowid or an equivalent
+ after an insert() construct executes, even if the table has a
+ Sequence on it.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def dbapi_lastrowid(self):
+ """target platform includes a 'lastrowid' accessor on the DBAPI
+ cursor object.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def views(self):
+ """Target database must support VIEWs."""
+
+ return exclusions.closed()
+
+ @property
+ def schemas(self):
+ """Target database must support external schemas, and have one
+ named 'test_schema'."""
+
+ return only_on(lambda config: config.db.dialect.supports_schemas)
+
+ @property
+ def cross_schema_fk_reflection(self):
+ """target system must support reflection of inter-schema
+ foreign keys"""
+ return exclusions.closed()
+
    @property
    def foreign_key_constraint_name_reflection(self):
        """Target supports reflection of FOREIGN KEY constraints and
        will return the name of the constraint that was used in the
        "CONSTRAINT <name> FOREIGN KEY" DDL.

        MySQL prior to version 8 and MariaDB prior to version 10.5
        don't support this.

        """
        return exclusions.closed()
+
+ @property
+ def implicit_default_schema(self):
+ """target system has a strong concept of 'default' schema that can
+ be referred to implicitly.
+
+ basically, PostgreSQL.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def default_schema_name_switch(self):
+ """target dialect implements provisioning module including
+ set_default_schema_on_connection"""
+
+ return exclusions.closed()
+
+ @property
+ def server_side_cursors(self):
+ """Target dialect must support server side cursors."""
+
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_server_side_cursors],
+ "no server side cursors support",
+ )
+
+ @property
+ def sequences(self):
+ """Target database must support SEQUENCEs."""
+
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.supports_sequences],
+ "no sequence support",
+ )
+
+ @property
+ def no_sequences(self):
+ """the opposite of "sequences", DB does not support sequences at
+ all."""
+
+ return exclusions.NotPredicate(self.sequences)
+
+ @property
+ def sequences_optional(self):
+ """Target database supports sequences, but also optionally
+ as a means of generating new PK values."""
+
+ return exclusions.only_if(
+ [
+ lambda config: config.db.dialect.supports_sequences
+ and config.db.dialect.sequences_optional
+ ],
+ "no sequence support, or sequences not optional",
+ )
+
+ @property
+ def supports_lastrowid(self):
+ """target database / driver supports cursor.lastrowid as a means
+ of retrieving the last inserted primary key value.
+
+ note that if the target DB supports sequences also, this is still
+ assumed to work. This is a new use case brought on by MariaDB 10.3.
+
+ """
+ return exclusions.only_if(
+ [lambda config: config.db.dialect.postfetch_lastrowid]
+ )
+
+ @property
+ def no_lastrowid_support(self):
+ """the opposite of supports_lastrowid"""
+ return exclusions.only_if(
+ [lambda config: not config.db.dialect.postfetch_lastrowid]
+ )
+
+ @property
+ def reflects_pk_names(self):
+ return exclusions.closed()
+
+ @property
+ def table_reflection(self):
+ """target database has general support for table reflection"""
+ return exclusions.open()
+
+ @property
+ def reflect_tables_no_columns(self):
+ """target database supports creation and reflection of tables with no
+ columns, or at least tables that seem to have no columns."""
+
+ return exclusions.closed()
+
+ @property
+ def comment_reflection(self):
+ return exclusions.closed()
+
+ @property
+ def view_column_reflection(self):
+ """target database must support retrieval of the columns in a view,
+ similarly to how a table is inspected.
+
+ This does not include the full CREATE VIEW definition.
+
+ """
+ return self.views
+
+ @property
+ def view_reflection(self):
+ """target database must support inspection of the full CREATE VIEW
+ definition."""
+ return self.views
+
+ @property
+ def schema_reflection(self):
+ return self.schemas
+
+ @property
+ def primary_key_constraint_reflection(self):
+ return exclusions.open()
+
+ @property
+ def foreign_key_constraint_reflection(self):
+ return exclusions.open()
+
+ @property
+ def foreign_key_constraint_option_reflection_ondelete(self):
+ return exclusions.closed()
+
+ @property
+ def fk_constraint_option_reflection_ondelete_restrict(self):
+ return exclusions.closed()
+
+ @property
+ def fk_constraint_option_reflection_ondelete_noaction(self):
+ return exclusions.closed()
+
+ @property
+ def foreign_key_constraint_option_reflection_onupdate(self):
+ return exclusions.closed()
+
+ @property
+ def fk_constraint_option_reflection_onupdate_restrict(self):
+ return exclusions.closed()
+
+ @property
+ def temp_table_reflection(self):
+ return exclusions.open()
+
+ @property
+ def temp_table_reflect_indexes(self):
+ return self.temp_table_reflection
+
+ @property
+ def temp_table_names(self):
+ """target dialect supports listing of temporary table names"""
+ return exclusions.closed()
+
+ @property
+ def temporary_tables(self):
+ """target database supports temporary tables"""
+ return exclusions.open()
+
+ @property
+ def temporary_views(self):
+ """target database supports temporary views"""
+ return exclusions.closed()
+
+ @property
+ def index_reflection(self):
+ return exclusions.open()
+
+ @property
+ def index_reflects_included_columns(self):
+ return exclusions.closed()
+
+ @property
+ def indexes_with_ascdesc(self):
+ """target database supports CREATE INDEX with per-column ASC/DESC."""
+ return exclusions.open()
+
+ @property
+ def indexes_with_expressions(self):
+ """target database supports CREATE INDEX against SQL expressions."""
+ return exclusions.closed()
+
+ @property
+ def unique_constraint_reflection(self):
+ """target dialect supports reflection of unique constraints"""
+ return exclusions.open()
+
+ @property
+ def check_constraint_reflection(self):
+ """target dialect supports reflection of check constraints"""
+ return exclusions.closed()
+
+ @property
+ def duplicate_key_raises_integrity_error(self):
+ """target dialect raises IntegrityError when reporting an INSERT
+ with a primary key violation. (hint: it should)
+
+ """
+ return exclusions.open()
+
+ @property
+ def unbounded_varchar(self):
+ """Target database must support VARCHAR with no length"""
+
+ return exclusions.open()
+
+ @property
+ def unicode_data(self):
+ """Target database/dialect must support Python unicode objects with
+ non-ASCII characters represented, delivered as bound parameters
+ as well as in result rows.
+
+ """
+ return exclusions.open()
+
+ @property
+ def unicode_ddl(self):
+ """Target driver must support some degree of non-ascii symbol
+ names.
+ """
+ return exclusions.closed()
+
+ @property
+ def symbol_names_w_double_quote(self):
+ """Target driver can create tables with a name like 'some " table'"""
+ return exclusions.open()
+
+ @property
+ def datetime_literals(self):
+ """target dialect supports rendering of a date, time, or datetime as a
+ literal string, e.g. via the TypeEngine.literal_processor() method.
+
+ """
+
+ return exclusions.closed()
+
+ @property
+ def datetime(self):
+ """target dialect supports representation of Python
+ datetime.datetime() objects."""
+
+ return exclusions.open()
+
+ @property
+ def datetime_timezone(self):
+ """target dialect supports representation of Python
+ datetime.datetime() with tzinfo with DateTime(timezone=True)."""
+
+ return exclusions.closed()
+
+ @property
+ def time_timezone(self):
+ """target dialect supports representation of Python
+ datetime.time() with tzinfo with Time(timezone=True)."""
+
+ return exclusions.closed()
+
+ @property
+ def datetime_implicit_bound(self):
+ """target dialect when given a datetime object will bind it such
+ that the database server knows the object is a datetime, and not
+ a plain string.
+
+ """
+ return exclusions.open()
+
+ @property
+ def datetime_microseconds(self):
+ """target dialect supports representation of Python
+ datetime.datetime() with microsecond objects."""
+
+ return exclusions.open()
+
+ @property
+ def timestamp_microseconds(self):
+ """target dialect supports representation of Python
+ datetime.datetime() with microsecond objects but only
+ if TIMESTAMP is used."""
+ return exclusions.closed()
+
+ @property
+ def timestamp_microseconds_implicit_bound(self):
+ """target dialect when given a datetime object which also includes
+ a microseconds portion when using the TIMESTAMP data type
+ will bind it such that the database server knows
+ the object is a datetime with microseconds, and not a plain string.
+
+ """
+ return self.timestamp_microseconds
+
+ @property
+ def datetime_historic(self):
+ """target dialect supports representation of Python
+ datetime.datetime() objects with historic (pre 1970) values."""
+
+ return exclusions.closed()
+
+ @property
+ def date(self):
+ """target dialect supports representation of Python
+ datetime.date() objects."""
+
+ return exclusions.open()
+
+ @property
+ def date_coerces_from_datetime(self):
+ """target dialect accepts a datetime object as the target
+ of a date column."""
+
+ return exclusions.open()
+
+ @property
+ def date_historic(self):
+ """target dialect supports representation of Python
+ datetime.datetime() objects with historic (pre 1970) values."""
+
+ return exclusions.closed()
+
+ @property
+ def time(self):
+ """target dialect supports representation of Python
+ datetime.time() objects."""
+
+ return exclusions.open()
+
+ @property
+ def time_microseconds(self):
+ """target dialect supports representation of Python
+ datetime.time() with microsecond objects."""
+
+ return exclusions.open()
+
+ @property
+ def binary_comparisons(self):
+ """target database/driver can allow BLOB/BINARY fields to be compared
+ against a bound parameter value.
+ """
+
+ return exclusions.open()
+
+ @property
+ def binary_literals(self):
+ """target backend supports simple binary literals, e.g. an
+ expression like::
+
+ SELECT CAST('foo' AS BINARY)
+
+ Where ``BINARY`` is the type emitted from :class:`.LargeBinary`,
+ e.g. it could be ``BLOB`` or similar.
+
+ Basically fails on Oracle.
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def autocommit(self):
+ """target dialect supports 'AUTOCOMMIT' as an isolation_level"""
+ return exclusions.closed()
+
+ @property
+ def isolation_level(self):
+ """target dialect supports general isolation level settings.
+
+ Note that this requirement, when enabled, also requires that
+ the get_isolation_levels() method be implemented.
+
+ """
+ return exclusions.closed()
+
+ def get_isolation_levels(self, config):
+ """Return a structure of supported isolation levels for the current
+ testing dialect.
+
+ The structure indicates to the testing suite what the expected
+ "default" isolation should be, as well as the other values that
+ are accepted. The dictionary has two keys, "default" and "supported".
+ The "supported" key refers to a list of all supported levels and
+ it should include AUTOCOMMIT if the dialect supports it.
+
+ If the :meth:`.DefaultRequirements.isolation_level` requirement is
+ not open, then this method has no return value.
+
+ E.g.::
+
+ >>> testing.requirements.get_isolation_levels()
+ {
+ "default": "READ_COMMITTED",
+ "supported": [
+ "SERIALIZABLE", "READ UNCOMMITTED",
+ "READ COMMITTED", "REPEATABLE READ",
+ "AUTOCOMMIT"
+ ]
+ }
+ """
+
+ @property
+ def json_type(self):
+ """target platform implements a native JSON type."""
+
+ return exclusions.closed()
+
+ @property
+ def json_array_indexes(self):
+ """target platform supports numeric array indexes
+ within a JSON structure"""
+
+ return self.json_type
+
+ @property
+ def json_index_supplementary_unicode_element(self):
+ return exclusions.open()
+
+ @property
+ def legacy_unconditional_json_extract(self):
+ """Backend has a JSON_EXTRACT or similar function that returns a
+ valid JSON string in all cases.
+
+ Used to test a legacy feature and is not needed.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def precision_numerics_general(self):
+ """target backend has general support for moderately high-precision
+ numerics."""
+ return exclusions.open()
+
+ @property
+ def precision_numerics_enotation_small(self):
+ """target backend supports Decimal() objects using E notation
+ to represent very small values."""
+ return exclusions.closed()
+
+ @property
+ def precision_numerics_enotation_large(self):
+ """target backend supports Decimal() objects using E notation
+ to represent very large values."""
+ return exclusions.closed()
+
+ @property
+ def precision_numerics_many_significant_digits(self):
+ """target backend supports values with many digits on both sides,
+ such as 319438950232418390.273596, 87673.594069654243
+
+ """
+ return exclusions.closed()
+
+ @property
+ def cast_precision_numerics_many_significant_digits(self):
+ """same as precision_numerics_many_significant_digits but within the
+ context of a CAST statement (hello MySQL)
+
+ """
+ return self.precision_numerics_many_significant_digits
+
+ @property
+ def implicit_decimal_binds(self):
+ """target backend will return a selected Decimal as a Decimal, not
+ a string.
+
+ e.g.::
+
+ expr = decimal.Decimal("15.7563")
+
+ value = e.scalar(
+ select(literal(expr))
+ )
+
+ assert value == expr
+
+ See :ticket:`4036`
+
+ """
+
+ return exclusions.open()
+
+ @property
+ def nested_aggregates(self):
+ """target database can select an aggregate from a subquery that's
+ also using an aggregate
+
+ """
+ return exclusions.open()
+
+ @property
+ def recursive_fk_cascade(self):
+ """target database must support ON DELETE CASCADE on a self-referential
+ foreign key
+
+ """
+ return exclusions.open()
+
+ @property
+ def precision_numerics_retains_significant_digits(self):
+ """A precision numeric type will return empty significant digits,
+ i.e. a value such as 10.000 will come back in Decimal form with
+ the .000 maintained."""
+
+ return exclusions.closed()
+
+ @property
+ def infinity_floats(self):
+ """The Float type can persist and load float('inf'), float('-inf')."""
+
+ return exclusions.closed()
+
+ @property
+ def precision_generic_float_type(self):
+ """target backend will return native floating point numbers with at
+ least seven decimal places when using the generic Float type.
+
+ """
+ return exclusions.open()
+
+ @property
+ def floats_to_four_decimals(self):
+ """target backend can return a floating-point number with four
+ significant digits (such as 15.7563) accurately
+ (i.e. without FP inaccuracies, such as 15.75629997253418).
+
+ """
+ return exclusions.open()
+
+ @property
+ def fetch_null_from_numeric(self):
+ """target backend doesn't crash when you try to select a NUMERIC
+ value that has a value of NULL.
+
+ Added to support Pyodbc bug #351.
+ """
+
+ return exclusions.open()
+
+ @property
+ def text_type(self):
+ """Target database must support an unbounded Text() "
+ "type such as TEXT or CLOB"""
+
+ return exclusions.open()
+
+ @property
+ def empty_strings_varchar(self):
+ """target database can persist/return an empty string with a
+ varchar.
+
+ """
+ return exclusions.open()
+
+ @property
+ def empty_strings_text(self):
+ """target database can persist/return an empty string with an
+ unbounded text."""
+
+ return exclusions.open()
+
+ @property
+ def expressions_against_unbounded_text(self):
+ """target database supports use of an unbounded textual field in a
+ WHERE clause."""
+
+ return exclusions.open()
+
+ @property
+ def selectone(self):
+ """target driver must support the literal statement 'select 1'"""
+ return exclusions.open()
+
+ @property
+ def savepoints(self):
+ """Target database must support savepoints."""
+
+ return exclusions.closed()
+
+ @property
+ def two_phase_transactions(self):
+ """Target database must support two-phase transactions."""
+
+ return exclusions.closed()
+
+ @property
+ def update_from(self):
+ """Target must support UPDATE..FROM syntax"""
+ return exclusions.closed()
+
+ @property
+ def delete_from(self):
+ """Target must support DELETE FROM..FROM or DELETE..USING syntax"""
+ return exclusions.closed()
+
+ @property
+ def update_where_target_in_subquery(self):
+ """Target must support UPDATE (or DELETE) where the same table is
+ present in a subquery in the WHERE clause.
+
+ This is an ANSI-standard syntax that apparently MySQL can't handle,
+ such as::
+
+ UPDATE documents SET flag=1 WHERE documents.title IN
+ (SELECT max(documents.title) AS title
+ FROM documents GROUP BY documents.user_id
+ )
+
+ """
+ return exclusions.open()
+
+ @property
+ def mod_operator_as_percent_sign(self):
+ """target database must use a plain percent '%' as the 'modulus'
+ operator."""
+ return exclusions.closed()
+
+ @property
+ def percent_schema_names(self):
+ """target backend supports weird identifiers with percent signs
+ in them, e.g. 'some % column'.
+
+ this is a very weird use case but often has problems because of
+ DBAPIs that use python formatting. It's not a critical use
+ case either.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def order_by_col_from_union(self):
+ """target database supports ordering by a column from a SELECT
+ inside of a UNION
+
+ E.g. (SELECT id, ...) UNION (SELECT id, ...) ORDER BY id
+
+ """
+ return exclusions.open()
+
+ @property
+ def order_by_label_with_expression(self):
+ """target backend supports ORDER BY a column label within an
+ expression.
+
+ Basically this::
+
+ select data as foo from test order by foo || 'bar'
+
+ Lots of databases including PostgreSQL don't support this,
+ so this is off by default.
+
+ """
+ return exclusions.closed()
+
+ @property
+ def order_by_collation(self):
+ def check(config):
+ try:
+ self.get_order_by_collation(config)
+ return False
+ except NotImplementedError:
+ return True
+
+ return exclusions.skip_if(check)
+
    def get_order_by_collation(self, config):
        """Return a collation name usable in an ORDER BY ... COLLATE clause.

        Dialect-specific requirement suites override this; the default
        raises, which causes the ``order_by_collation`` exclusion to skip
        the test.
        """
        raise NotImplementedError()
+
+ @property
+ def unicode_connections(self):
+ """Target driver must support non-ASCII characters being passed at
+ all.
+ """
+ return exclusions.open()
+
+ @property
+ def graceful_disconnects(self):
+ """Target driver must raise a DBAPI-level exception, such as
+ InterfaceError, when the underlying connection has been closed
+ and the execute() method is called.
+ """
+ return exclusions.open()
+
+ @property
+ def independent_connections(self):
+ """
+ Target must support simultaneous, independent database connections.
+ """
+ return exclusions.open()
+
+ @property
+ def skip_mysql_on_windows(self):
+ """Catchall for a large variety of MySQL on Windows failures"""
+ return exclusions.open()
+
+ @property
+ def ad_hoc_engines(self):
+ """Test environment must allow ad-hoc engine/connection creation.
+
+ DBs that scale poorly for many connections, even when closed, i.e.
+ Oracle, may use the "--low-connections" option which flags this
+ requirement as not present.
+
+ """
+ return exclusions.skip_if(
+ lambda config: config.options.low_connections
+ )
+
    @property
    def no_windows(self):
        # skip any test so marked when the suite runs on a Windows host
        return exclusions.skip_if(self._running_on_windows())
+
    def _running_on_windows(self):
        # predicate evaluated lazily at test-collection time, not at
        # import time
        return exclusions.LambdaPredicate(
            lambda: platform.system() == "Windows",
            description="running on Windows",
        )
+
+ @property
+ def timing_intensive(self):
+ return exclusions.requires_tag("timing_intensive")
+
+ @property
+ def memory_intensive(self):
+ return exclusions.requires_tag("memory_intensive")
+
+ @property
+ def threading_with_mock(self):
+ """Mark tests that use threading and mock at the same time - stability
+ issues have been observed with coverage + python 3.3
+
+ """
+ return exclusions.skip_if(
+ lambda config: util.py3k and config.options.has_coverage,
+ "Stability issues with coverage + py3k",
+ )
+
    @property
    def sqlalchemy2_stubs(self):
        # run only when the mypy stub plugin for SQLAlchemy is importable
        def check(config):
            # NOTE(review): "sqlalchemy-stubs" contains a hyphen, which is
            # not a valid Python module name, so this __import__ call can
            # never succeed and check() always returns False.  Confirm
            # whether "sqlalchemy2_stubs" (the importable module of the
            # sqlalchemy2-stubs distribution) was intended here.
            try:
                __import__("sqlalchemy-stubs.ext.mypy")
            except ImportError:
                return False
            else:
                return True

        return exclusions.only_if(check)
+
+ @property
+ def python2(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info >= (3,),
+ "Python version 2.xx is required.",
+ )
+
+ @property
+ def python3(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info < (3,), "Python version 3.xx is required."
+ )
+
+ @property
+ def pep520(self):
+ return self.python36
+
+ @property
+ def insert_order_dicts(self):
+ return self.python37
+
+ @property
+ def python36(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info < (3, 6),
+ "Python version 3.6 or greater is required.",
+ )
+
+ @property
+ def python37(self):
+ return exclusions.skip_if(
+ lambda: sys.version_info < (3, 7),
+ "Python version 3.7 or greater is required.",
+ )
+
+ @property
+ def dataclasses(self):
+ return self.python37
+
+ @property
+ def python38(self):
+ return exclusions.only_if(
+ lambda: util.py38, "Python 3.8 or above required"
+ )
+
+ @property
+ def cpython(self):
+ return exclusions.only_if(
+ lambda: util.cpython, "cPython interpreter needed"
+ )
+
+ @property
+ def patch_library(self):
+ def check_lib():
+ try:
+ __import__("patch")
+ except ImportError:
+ return False
+ else:
+ return True
+
+ return exclusions.only_if(check_lib, "patch library needed")
+
    @property
    def non_broken_pickle(self):
        from sqlalchemy.util import pickle

        # note operator precedence: the condition reads as
        #   (cpython AND pickle is the C "cPickle" module)
        #   OR Python >= 3.2
        # i.e. any modern Python 3 pickle qualifies, while on older
        # interpreters only cPickle under CPython does.
        return exclusions.only_if(
            lambda: util.cpython
            and pickle.__name__ == "cPickle"
            or sys.version_info >= (3, 2),
            "Needs cPickle+cPython or newer Python 3 pickle",
        )
+
+ @property
+ def predictable_gc(self):
+ """target platform must remove all cycles unconditionally when
+ gc.collect() is called, as well as clean out unreferenced subclasses.
+
+ """
+ return self.cpython
+
+ @property
+ def no_coverage(self):
+ """Test should be skipped if coverage is enabled.
+
+ This is to block tests that exercise libraries that seem to be
+ sensitive to coverage, such as PostgreSQL notice logging.
+
+ """
+ return exclusions.skip_if(
+ lambda config: config.options.has_coverage,
+ "Issues observed when coverage is enabled",
+ )
+
+ def _has_mysql_on_windows(self, config):
+ return False
+
+ def _has_mysql_fully_case_sensitive(self, config):
+ return False
+
+ @property
+ def sqlite(self):
+ return exclusions.skip_if(lambda: not self._has_sqlite())
+
+ @property
+ def cextensions(self):
+ return exclusions.skip_if(
+ lambda: not util.has_compiled_ext(), "C extensions not installed"
+ )
+
+ def _has_sqlite(self):
+ from sqlalchemy import create_engine
+
+ try:
+ create_engine("sqlite://")
+ return True
+ except ImportError:
+ return False
+
+ @property
+ def async_dialect(self):
+ """dialect makes use of await_() to invoke operations on the DBAPI."""
+
+ return exclusions.closed()
+
+ @property
+ def asyncio(self):
+ return self.greenlet
+
+ @property
+ def greenlet(self):
+ def go(config):
+ try:
+ import greenlet # noqa: F401
+ except ImportError:
+ return False
+ else:
+ return True
+
+ return exclusions.only_if(go)
+
+ @property
+ def computed_columns(self):
+ "Supports computed columns"
+ return exclusions.closed()
+
+ @property
+ def computed_columns_stored(self):
+ "Supports computed columns with `persisted=True`"
+ return exclusions.closed()
+
+ @property
+ def computed_columns_virtual(self):
+ "Supports computed columns with `persisted=False`"
+ return exclusions.closed()
+
+ @property
+ def computed_columns_default_persisted(self):
+ """If the default persistence is virtual or stored when `persisted`
+ is omitted"""
+ return exclusions.closed()
+
+ @property
+ def computed_columns_reflect_persisted(self):
+ """If persistence information is returned by the reflection of
+ computed columns"""
+ return exclusions.closed()
+
+ @property
+ def supports_distinct_on(self):
+ """If a backend supports the DISTINCT ON in a select"""
+ return exclusions.closed()
+
+ @property
+ def supports_is_distinct_from(self):
+ """Supports some form of "x IS [NOT] DISTINCT FROM y" construct.
+ Different dialects will implement their own flavour, e.g.,
+ sqlite will emit "x IS NOT y" instead of "x IS DISTINCT FROM y".
+
+ .. seealso::
+
+ :meth:`.ColumnOperators.is_distinct_from`
+
+ """
+ return exclusions.skip_if(
+ lambda config: not config.db.dialect.supports_is_distinct_from,
+ "driver doesn't support an IS DISTINCT FROM construct",
+ )
+
+ @property
+ def identity_columns(self):
+ """If a backend supports GENERATED { ALWAYS | BY DEFAULT }
+ AS IDENTITY"""
+ return exclusions.closed()
+
+ @property
+ def identity_columns_standard(self):
+ """If a backend supports GENERATED { ALWAYS | BY DEFAULT }
+ AS IDENTITY with a standard syntax.
+ This is mainly to exclude MSSql.
+ """
+ return exclusions.closed()
+
+ @property
+ def regexp_match(self):
+ """backend supports the regexp_match operator."""
+ return exclusions.closed()
+
+ @property
+ def regexp_replace(self):
+ """backend supports the regexp_replace operator."""
+ return exclusions.closed()
+
+ @property
+ def fetch_first(self):
+ """backend supports the fetch first clause."""
+ return exclusions.closed()
+
+ @property
+ def fetch_percent(self):
+ """backend supports the fetch first clause with percent."""
+ return exclusions.closed()
+
+ @property
+ def fetch_ties(self):
+ """backend supports the fetch first clause with ties."""
+ return exclusions.closed()
+
+ @property
+ def fetch_no_order_by(self):
+ """backend supports the fetch first without order by"""
+ return exclusions.closed()
+
+ @property
+ def fetch_offset_with_options(self):
+ """backend supports the offset when using fetch first with percent
+ or ties. basically this is "not mssql"
+ """
+ return exclusions.closed()
+
+ @property
+ def fetch_expression(self):
+ """backend supports fetch / offset with expression in them, like
+
+ SELECT * FROM some_table
+ OFFSET 1 + 1 ROWS FETCH FIRST 1 + 1 ROWS ONLY
+ """
+ return exclusions.closed()
+
+ @property
+ def autoincrement_without_sequence(self):
+ """If autoincrement=True on a column does not require an explicit
+ sequence. This should be false only for oracle.
+ """
+ return exclusions.open()
+
+ @property
+ def generic_classes(self):
+ "If X[Y] can be implemented with ``__class_getitem__``. py3.7+"
+ return exclusions.only_if(lambda: util.py37)
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py
new file mode 100644
index 0000000..bff07a5
--- /dev/null
+++ b/lib/sqlalchemy/testing/schema.py
@@ -0,0 +1,218 @@
+# testing/schema.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import sys
+
+from . import config
+from . import exclusions
+from .. import event
+from .. import schema
+from .. import types as sqltypes
+from ..util import OrderedDict
+
+
+__all__ = ["Table", "Column"]
+
+table_options = {}
+
+
def Table(*args, **kw):
    """A schema.Table wrapper/hook for dialect-specific tweaks.

    Accepts the same arguments as :class:`sqlalchemy.schema.Table`, plus
    any number of ``test_*`` keyword options, which are consumed here and
    not passed through to the Table constructor.
    """

    # pull out the "test_*" directives; everything else goes to Table()
    test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}

    # merge in module-level defaults (may be populated by a test suite)
    kw.update(table_options)

    if exclusions.against(config._current, "mysql"):
        # pick a storage engine: InnoDB when the test declares it needs
        # foreign keys or ACID semantics, MyISAM otherwise.  Reflection
        # (autoload_with) and explicit engine/type settings are left alone.
        if (
            "mysql_engine" not in kw
            and "mysql_type" not in kw
            and "autoload_with" not in kw
        ):
            if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
                kw["mysql_engine"] = "InnoDB"
            else:
                kw["mysql_engine"] = "MyISAM"
    elif exclusions.against(config._current, "mariadb"):
        # same engine selection, under the mariadb dialect's kwarg prefix
        if (
            "mariadb_engine" not in kw
            and "mariadb_type" not in kw
            and "autoload_with" not in kw
        ):
            if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
                kw["mariadb_engine"] = "InnoDB"
            else:
                kw["mariadb_engine"] = "MyISAM"

    # Apply some default cascading rules for self-referential foreign keys.
    # MySQL InnoDB has some issues around selecting self-refs too.
    if exclusions.against(config._current, "firebird"):
        table_name = args[0]
        unpack = config.db.dialect.identifier_preparer.unformat_identifiers

        # Only going after ForeignKeys in Columns.  May need to
        # expand to ForeignKeyConstraint too.
        fks = [
            fk
            for col in args
            if isinstance(col, schema.Column)
            for fk in col.foreign_keys
        ]

        for fk in fks:
            # root around in raw spec
            ref = fk._colspec
            if isinstance(ref, schema.Column):
                name = ref.table.name
            else:
                # take just the table name: on FB there cannot be
                # a schema, so the first element is always the
                # table name, possibly followed by the field name
                name = unpack(ref)[0]
            if name == table_name:
                # self-referential FK: default both rules to CASCADE,
                # without overriding explicit settings
                if fk.ondelete is None:
                    fk.ondelete = "CASCADE"
                if fk.onupdate is None:
                    fk.onupdate = "CASCADE"

    return schema.Table(*args, **kw)
+
+
def Column(*args, **kw):
    """A schema.Column wrapper/hook for dialect-specific tweaks.

    Consumes ``test_*`` keyword options (currently
    ``test_needs_autoincrement``) before constructing the Column.
    """

    test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}

    # when the backend can't emit foreign key DDL, silently drop
    # ForeignKey constructs so the table still builds
    if not config.requirements.foreign_key_ddl.enabled_for_config(config):
        args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]

    col = schema.Column(*args, **kw)
    if test_opts.get("test_needs_autoincrement", False) and kw.get(
        "primary_key", False
    ):

        # only flag autoincrement when no explicit default is present
        if col.default is None and col.server_default is None:
            col.autoincrement = True

        # allow any test suite to pick up on this
        col.info["test_needs_autoincrement"] = True

        # hardcoded rule for firebird, oracle; this should
        # be moved out
        if exclusions.against(config._current, "firebird", "oracle"):

            def add_seq(c, tbl):
                # attach an optional Sequence named <table>_<col>_seq,
                # truncated to the dialect's identifier length limit
                c._init_items(
                    schema.Sequence(
                        _truncate_name(
                            config.db.dialect, tbl.name + "_" + c.name + "_seq"
                        ),
                        optional=True,
                    )
                )

            event.listen(col, "after_parent_attach", add_seq, propagate=True)
    return col
+
+
class eq_type_affinity(object):
    """Comparison helper matching SQL types by type affinity.

    An instance compares equal to any type object whose
    ``_type_affinity`` is the same as the wrapped target's, which lets
    reflection results be compared structurally::

        eq_(
            inspect(connection).get_columns("foo"),
            [
                {
                    "name": "id",
                    "type": testing.eq_type_affinity(sqltypes.INTEGER),
                    "nullable": False,
                    "default": None,
                    "autoincrement": False,
                },
                {
                    "name": "data",
                    "type": testing.eq_type_affinity(sqltypes.NullType),
                    "nullable": True,
                    "default": None,
                    "autoincrement": False,
                },
            ],
        )

    """

    def __init__(self, target):
        # normalize a type class to an instance
        self.target = sqltypes.to_instance(target)

    def __eq__(self, other):
        return self.target._type_affinity is other._type_affinity

    def __ne__(self, other):
        # defined in terms of __eq__ so the two can never disagree
        return not self.__eq__(other)
+
+
class eq_clause_element(object):
    """Comparison helper for SQL constructs; equality is delegated to the
    wrapped element's ``compare()`` method."""

    def __init__(self, target):
        # the expected SQL construct
        self.target = target

    def __eq__(self, other):
        return self.target.compare(other)

    def __ne__(self, other):
        return not self.target.compare(other)
+
+
def _truncate_name(dialect, name):
    """Shorten ``name`` to fit the dialect's identifier length limit.

    Names within the limit pass through unchanged; longer names keep
    their leading characters and receive a short hash-derived suffix.
    """
    limit = dialect.max_identifier_length
    if len(name) <= limit:
        return name
    suffix = hex(hash(name) % 64)[2:]
    return name[0 : max(limit - 6, 0)] + "_" + suffix
+
+
def pep435_enum(name):
    """Create a minimal PEP 435-style enumeration class.

    Implements just enough of the PEP 435 protocol for SQLAlchemy's
    Enum tests: a ``__members__`` mapping, attribute access for each
    member, optional aliases, and a ``get()`` lookup by value.

    :param name: class name for the new enumeration.
    :return: a new class; calling it as ``cls(name, value, alias=None)``
     creates and registers a member.
    """
    __members__ = OrderedDict()

    def __init__(self, name, value, alias=None):
        self.name = name
        self.value = value
        self.__members__[name] = self
        value_to_member[value] = self
        setattr(self.__class__, name, self)
        if alias:
            self.__members__[alias] = self
            setattr(self.__class__, alias, self)

    value_to_member = {}

    @classmethod
    def get(cls, value):
        return value_to_member[value]

    someenum = type(
        name,
        (object,),
        {"__members__": __members__, "__init__": __init__, "get": get},
    )

    # getframe() trick for pickling I don't understand courtesy
    # Python namedtuple()
    #
    # bugfix: initialize module so the "if module is not None" check below
    # cannot raise NameError when sys._getframe is unavailable or raises
    # (e.g. on non-CPython implementations)
    module = None
    try:
        module = sys._getframe(1).f_globals.get("__name__", "__main__")
    except (AttributeError, ValueError):
        pass
    if module is not None:
        someenum.__module__ = module

    return someenum
diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py
new file mode 100644
index 0000000..30817e1
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/__init__.py
@@ -0,0 +1,13 @@
+from .test_cte import * # noqa
+from .test_ddl import * # noqa
+from .test_deprecations import * # noqa
+from .test_dialect import * # noqa
+from .test_insert import * # noqa
+from .test_reflection import * # noqa
+from .test_results import * # noqa
+from .test_rowcount import * # noqa
+from .test_select import * # noqa
+from .test_sequence import * # noqa
+from .test_types import * # noqa
+from .test_unicode_ddl import * # noqa
+from .test_update_delete import * # noqa
diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py
new file mode 100644
index 0000000..a94ee55
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_cte.py
@@ -0,0 +1,204 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import ForeignKey
+from ... import Integer
+from ... import select
+from ... import String
+from ... import testing
+
+
class CTETest(fixtures.TablesTest):
    """Round-trip tests for common table expression (WITH clause) support.

    Enabled only for dialects declaring the "ctes" requirement; several
    tests additionally require "ctes_with_update_delete", "update_from"
    or "delete_from".
    """

    __backend__ = True
    __requires__ = ("ctes",)

    # rows are re-inserted/deleted for each test, since several tests
    # mutate the tables
    run_inserts = "each"
    run_deletes = "each"

    @classmethod
    def define_tables(cls, metadata):
        # "some_table" is self-referential via parent_id -> id
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", ForeignKey("some_table.id")),
        )

        # same shape but without the FK; used as the writable target of
        # the INSERT/UPDATE/DELETE tests
        Table(
            "some_other_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", Integer),
        )

    @classmethod
    def insert_data(cls, connection):
        # a small tree: 1 is the root; 2, 3 are children of 1;
        # 4, 5 are children of 3
        connection.execute(
            cls.tables.some_table.insert(),
            [
                {"id": 1, "data": "d1", "parent_id": None},
                {"id": 2, "data": "d2", "parent_id": 1},
                {"id": 3, "data": "d3", "parent_id": 1},
                {"id": 4, "data": "d4", "parent_id": 3},
                {"id": 5, "data": "d5", "parent_id": 3},
            ],
        )

    def test_select_nonrecursive_round_trip(self, connection):
        """A plain (non-recursive) CTE can be selected from."""
        some_table = self.tables.some_table

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        result = connection.execute(
            select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
        )
        # "d5" is filtered out by the CTE itself
        eq_(result.fetchall(), [("d4",)])

    def test_select_recursive_round_trip(self, connection):
        """A recursive CTE unioned with a self-join walks up the tree."""
        some_table = self.tables.some_table

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte", recursive=True)
        )

        cte_alias = cte.alias("c1")
        st1 = some_table.alias()
        # note that SQL Server requires this to be UNION ALL,
        # can't be UNION
        cte = cte.union_all(
            select(st1).where(st1.c.id == cte_alias.c.parent_id)
        )
        result = connection.execute(
            select(cte.c.data)
            .where(cte.c.data != "d2")
            .order_by(cte.c.data.desc())
        )
        eq_(
            result.fetchall(),
            [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
        )

    def test_insert_from_select_round_trip(self, connection):
        """INSERT ... FROM SELECT can draw its rows from a CTE."""
        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(cte)
            )
        )
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
        )

    @testing.requires.ctes_with_update_delete
    @testing.requires.update_from
    def test_update_from_round_trip(self, connection):
        """UPDATE ... FROM can correlate against a CTE."""
        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table

        # seed the target table with a full copy of some_table
        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(some_table)
            )
        )

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.update()
            .values(parent_id=5)
            .where(some_other_table.c.data == cte.c.data)
        )
        # only the rows matched by the CTE ("d2", "d3", "d4") were updated
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [
                (1, "d1", None),
                (2, "d2", 5),
                (3, "d3", 5),
                (4, "d4", 5),
                (5, "d5", 3),
            ],
        )

    @testing.requires.ctes_with_update_delete
    @testing.requires.delete_from
    def test_delete_from_round_trip(self, connection):
        """DELETE ... FROM / USING can correlate against a CTE."""
        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table

        # seed the target table with a full copy of some_table
        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(some_table)
            )
        )

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.delete().where(
                some_other_table.c.data == cte.c.data
            )
        )
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [(1, "d1", None), (5, "d5", 3)],
        )

    @testing.requires.ctes_with_update_delete
    def test_delete_scalar_subq_round_trip(self, connection):
        """DELETE with a scalar subquery that selects from a CTE."""

        some_table = self.tables.some_table
        some_other_table = self.tables.some_other_table

        # seed the target table with a full copy of some_table
        connection.execute(
            some_other_table.insert().from_select(
                ["id", "data", "parent_id"], select(some_table)
            )
        )

        cte = (
            select(some_table)
            .where(some_table.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        connection.execute(
            some_other_table.delete().where(
                some_other_table.c.data
                == select(cte.c.data)
                .where(cte.c.id == some_other_table.c.id)
                .scalar_subquery()
            )
        )
        eq_(
            connection.execute(
                select(some_other_table).order_by(some_other_table.c.id)
            ).fetchall(),
            [(1, "d1", None), (5, "d5", 3)],
        )
diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py
new file mode 100644
index 0000000..b3fee55
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_ddl.py
@@ -0,0 +1,381 @@
+import random
+
+from . import testing
+from .. import config
+from .. import fixtures
+from .. import util
+from ..assertions import eq_
+from ..assertions import is_false
+from ..assertions import is_true
+from ..config import requirements
+from ..schema import Table
+from ... import CheckConstraint
+from ... import Column
+from ... import ForeignKeyConstraint
+from ... import Index
+from ... import inspect
+from ... import Integer
+from ... import schema
+from ... import String
+from ... import UniqueConstraint
+
+
+class TableDDLTest(fixtures.TestBase):
+ __backend__ = True
+
+ def _simple_fixture(self, schema=None):
+ return Table(
+ "test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ schema=schema,
+ )
+
+ def _underscore_fixture(self):
+ return Table(
+ "_test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("_data", String(50)),
+ )
+
+ def _table_index_fixture(self, schema=None):
+ table = self._simple_fixture(schema=schema)
+ idx = Index("test_index", table.c.data)
+ return table, idx
+
+ def _simple_roundtrip(self, table):
+ with config.db.begin() as conn:
+ conn.execute(table.insert().values((1, "some data")))
+ result = conn.execute(table.select())
+ eq_(result.first(), (1, "some data"))
+
+ @requirements.create_table
+ @util.provide_metadata
+ def test_create_table(self):
+ table = self._simple_fixture()
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.create_table
+ @requirements.schemas
+ @util.provide_metadata
+ def test_create_table_schema(self):
+ table = self._simple_fixture(schema=config.test_schema)
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.drop_table
+ @util.provide_metadata
+ def test_drop_table(self):
+ table = self._simple_fixture()
+ table.create(config.db, checkfirst=False)
+ table.drop(config.db, checkfirst=False)
+
+ @requirements.create_table
+ @util.provide_metadata
+ def test_underscore_names(self):
+ table = self._underscore_fixture()
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.comment_reflection
+ @util.provide_metadata
+ def test_add_table_comment(self, connection):
+ table = self._simple_fixture()
+ table.create(connection, checkfirst=False)
+ table.comment = "a comment"
+ connection.execute(schema.SetTableComment(table))
+ eq_(
+ inspect(connection).get_table_comment("test_table"),
+ {"text": "a comment"},
+ )
+
+ @requirements.comment_reflection
+ @util.provide_metadata
+ def test_drop_table_comment(self, connection):
+ table = self._simple_fixture()
+ table.create(connection, checkfirst=False)
+ table.comment = "a comment"
+ connection.execute(schema.SetTableComment(table))
+ connection.execute(schema.DropTableComment(table))
+ eq_(
+ inspect(connection).get_table_comment("test_table"), {"text": None}
+ )
+
+ @requirements.table_ddl_if_exists
+ @util.provide_metadata
+ def test_create_table_if_not_exists(self, connection):
+ table = self._simple_fixture()
+
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+
+ is_true(inspect(connection).has_table("test_table"))
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+
+ @requirements.index_ddl_if_exists
+ @util.provide_metadata
+ def test_create_index_if_not_exists(self, connection):
+ table, idx = self._table_index_fixture()
+
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+ is_true(inspect(connection).has_table("test_table"))
+ is_false(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.CreateIndex(idx, if_not_exists=True))
+
+ is_true(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.CreateIndex(idx, if_not_exists=True))
+
+ @requirements.table_ddl_if_exists
+ @util.provide_metadata
+ def test_drop_table_if_exists(self, connection):
+ table = self._simple_fixture()
+
+ table.create(connection)
+
+ is_true(inspect(connection).has_table("test_table"))
+
+ connection.execute(schema.DropTable(table, if_exists=True))
+
+ is_false(inspect(connection).has_table("test_table"))
+
+ connection.execute(schema.DropTable(table, if_exists=True))
+
+ @requirements.index_ddl_if_exists
+ @util.provide_metadata
+ def test_drop_index_if_exists(self, connection):
+ table, idx = self._table_index_fixture()
+
+ table.create(connection)
+
+ is_true(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.DropIndex(idx, if_exists=True))
+
+ is_false(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.DropIndex(idx, if_exists=True))
+
+
+class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest):
+ pass
+
+
+class LongNameBlowoutTest(fixtures.TestBase):
+ """test the creation of a variety of DDL structures and ensure
+ label length limits pass on backends
+
+ """
+
+ __backend__ = True
+
+ def fk(self, metadata, connection):
+ convention = {
+ "fk": "foreign_key_%(table_name)s_"
+ "%(column_0_N_name)s_"
+ "%(referred_table_name)s_"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(20))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ test_needs_fk=True,
+ )
+
+ cons = ForeignKeyConstraint(
+ ["aid"], ["a_things_with_stuff.id_long_column_name"]
+ )
+ Table(
+ "b_related_things_of_value",
+ metadata,
+ Column(
+ "aid",
+ ),
+ cons,
+ test_needs_fk=True,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+
+ if testing.requires.foreign_key_constraint_name_reflection.enabled:
+ insp = inspect(connection)
+ fks = insp.get_foreign_keys("b_related_things_of_value")
+ reflected_name = fks[0]["name"]
+
+ return actual_name, reflected_name
+ else:
+ return actual_name, None
+
+ def pk(self, metadata, connection):
+ convention = {
+ "pk": "primary_key_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ a = Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer, primary_key=True),
+ )
+ cons = a.primary_key
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ pk = insp.get_pk_constraint("a_things_with_stuff")
+ reflected_name = pk["name"]
+ return actual_name, reflected_name
+
+ def ix(self, metadata, connection):
+ convention = {
+ "ix": "index_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ a = Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer),
+ )
+ cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name)
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ ix = insp.get_indexes("a_things_with_stuff")
+ reflected_name = ix[0]["name"]
+ return actual_name, reflected_name
+
+ def uq(self, metadata, connection):
+ convention = {
+ "uq": "unique_constraint_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ cons = UniqueConstraint("id_long_column_name", "id_another_long_name")
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer),
+ cons,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ uq = insp.get_unique_constraints("a_things_with_stuff")
+ reflected_name = uq[0]["name"]
+ return actual_name, reflected_name
+
+ def ck(self, metadata, connection):
+ convention = {
+ "ck": "check_constraint_%(table_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ cons = CheckConstraint("some_long_column_name > 5")
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("some_long_column_name", Integer),
+ cons,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ ck = insp.get_check_constraints("a_things_with_stuff")
+ reflected_name = ck[0]["name"]
+ return actual_name, reflected_name
+
+ @testing.combinations(
+ ("fk",),
+ ("pk",),
+ ("ix",),
+ ("ck", testing.requires.check_constraint_reflection.as_skips()),
+ ("uq", testing.requires.unique_constraint_reflection.as_skips()),
+ argnames="type_",
+ )
+ def test_long_convention_name(self, type_, metadata, connection):
+ actual_name, reflected_name = getattr(self, type_)(
+ metadata, connection
+ )
+
+ assert len(actual_name) > 255
+
+ if reflected_name is not None:
+ overlap = actual_name[0 : len(reflected_name)]
+ if len(overlap) < len(actual_name):
+ eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5])
+ else:
+ eq_(overlap, reflected_name)
+
+
+__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest")
diff --git a/lib/sqlalchemy/testing/suite/test_deprecations.py b/lib/sqlalchemy/testing/suite/test_deprecations.py
new file mode 100644
index 0000000..b36162f
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_deprecations.py
@@ -0,0 +1,145 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import select
+from ... import testing
+from ... import union
+
+
+class DeprecatedCompoundSelectTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ ],
+ )
+
+ def _assert_result(self, conn, select, result, params=()):
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ def test_plain_union(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ # note we've had to remove one use case entirely, which is this
+ # one. the Select gets its FROMS from the WHERE clause and the
+ # columns clause, but not the ORDER BY, which means the old ".c" system
+ # allowed you to "order_by(s.c.foo)" to get an unnamed column in the
+ # ORDER BY without adding the SELECT into the FROM and breaking the
+ # query. Users will have to adjust for this use case if they were doing
+ # it before.
+ def _dont_test_select_from_plain_union(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2).alias().select()
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.order_by_col_from_union
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_wo_limit_offset
+ def test_order_by_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_distinct_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).distinct()
+ s2 = select(table).where(table.c.id == 3).distinct()
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_limit_offset_aliased_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = (
+ select(table)
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+ s2 = (
+ select(table)
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py
new file mode 100644
index 0000000..c2c17d0
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_dialect.py
@@ -0,0 +1,361 @@
+#! coding: utf-8
+
+from . import testing
+from .. import assert_raises
+from .. import config
+from .. import engines
+from .. import eq_
+from .. import fixtures
+from .. import ne_
+from .. import provide_metadata
+from ..config import requirements
+from ..provision import set_default_schema_on_connection
+from ..schema import Column
+from ..schema import Table
+from ... import bindparam
+from ... import event
+from ... import exc
+from ... import Integer
+from ... import literal_column
+from ... import select
+from ... import String
+from ...util import compat
+
+
+class ExceptionTest(fixtures.TablesTest):
+ """Test basic exception wrapping.
+
+ DBAPIs vary a lot in exception behavior so to actually anticipate
+ specific exceptions from real round trips, we need to be conservative.
+
+ """
+
+ run_deletes = "each"
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+
+ @requirements.duplicate_key_raises_integrity_error
+ def test_integrity_error(self):
+
+ with config.db.connect() as conn:
+
+ trans = conn.begin()
+ conn.execute(
+ self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
+ )
+
+ assert_raises(
+ exc.IntegrityError,
+ conn.execute,
+ self.tables.manual_pk.insert(),
+ {"id": 1, "data": "d1"},
+ )
+
+ trans.rollback()
+
+ def test_exception_with_non_ascii(self):
+ with config.db.connect() as conn:
+ try:
+ # try to create an error message that likely has non-ascii
+ # characters in the DBAPI's message string. unfortunately
+ # there's no way to make this happen with some drivers like
+ # mysqlclient, pymysql. this at least does produce a non-
+ # ascii error message for cx_oracle, psycopg2
+ conn.execute(select(literal_column(u"méil")))
+ assert False
+ except exc.DBAPIError as err:
+ err_str = str(err)
+
+ assert str(err.orig) in str(err)
+
+ # test that we are actually getting string on Py2k, unicode
+ # on Py3k.
+ if compat.py2k:
+ assert isinstance(err_str, str)
+ else:
+ assert isinstance(err_str, str)
+
+
+class IsolationLevelTest(fixtures.TestBase):
+ __backend__ = True
+
+ __requires__ = ("isolation_level",)
+
+ def _get_non_default_isolation_level(self):
+ levels = requirements.get_isolation_levels(config)
+
+ default = levels["default"]
+ supported = levels["supported"]
+
+ s = set(supported).difference(["AUTOCOMMIT", default])
+ if s:
+ return s.pop()
+ else:
+ config.skip_test("no non-default isolation level available")
+
+ def test_default_isolation_level(self):
+ eq_(
+ config.db.dialect.default_isolation_level,
+ requirements.get_isolation_levels(config)["default"],
+ )
+
+ def test_non_default_isolation_level(self):
+ non_default = self._get_non_default_isolation_level()
+
+ with config.db.connect() as conn:
+ existing = conn.get_isolation_level()
+
+ ne_(existing, non_default)
+
+ conn.execution_options(isolation_level=non_default)
+
+ eq_(conn.get_isolation_level(), non_default)
+
+ conn.dialect.reset_isolation_level(conn.connection)
+
+ eq_(conn.get_isolation_level(), existing)
+
+ def test_all_levels(self):
+ levels = requirements.get_isolation_levels(config)
+
+ all_levels = levels["supported"]
+
+ for level in set(all_levels).difference(["AUTOCOMMIT"]):
+ with config.db.connect() as conn:
+ conn.execution_options(isolation_level=level)
+
+ eq_(conn.get_isolation_level(), level)
+
+ trans = conn.begin()
+ trans.rollback()
+
+ eq_(conn.get_isolation_level(), level)
+
+ with config.db.connect() as conn:
+ eq_(
+ conn.get_isolation_level(),
+ levels["default"],
+ )
+
+
+class AutocommitIsolationTest(fixtures.TablesTest):
+
+ run_deletes = "each"
+
+ __requires__ = ("autocommit",)
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ test_needs_acid=True,
+ )
+
+ def _test_conn_autocommits(self, conn, autocommit):
+ trans = conn.begin()
+ conn.execute(
+ self.tables.some_table.insert(), {"id": 1, "data": "some data"}
+ )
+ trans.rollback()
+
+ eq_(
+ conn.scalar(select(self.tables.some_table.c.id)),
+ 1 if autocommit else None,
+ )
+
+ with conn.begin():
+ conn.execute(self.tables.some_table.delete())
+
+ def test_autocommit_on(self, connection_no_trans):
+ conn = connection_no_trans
+ c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
+ self._test_conn_autocommits(c2, True)
+
+ c2.dialect.reset_isolation_level(c2.connection)
+
+ self._test_conn_autocommits(conn, False)
+
+ def test_autocommit_off(self, connection_no_trans):
+ conn = connection_no_trans
+ self._test_conn_autocommits(conn, False)
+
+ def test_turn_autocommit_off_via_default_iso_level(
+ self, connection_no_trans
+ ):
+ conn = connection_no_trans
+ conn = conn.execution_options(isolation_level="AUTOCOMMIT")
+ self._test_conn_autocommits(conn, True)
+
+ conn.execution_options(
+ isolation_level=requirements.get_isolation_levels(config)[
+ "default"
+ ]
+ )
+ self._test_conn_autocommits(conn, False)
+
+
+class EscapingTest(fixtures.TestBase):
+ @provide_metadata
+ def test_percent_sign_round_trip(self):
+ """test that the DBAPI accommodates for escaped / nonescaped
+ percent signs in a way that matches the compiler
+
+ """
+ m = self.metadata
+ t = Table("t", m, Column("data", String(50)))
+ t.create(config.db)
+ with config.db.begin() as conn:
+ conn.execute(t.insert(), dict(data="some % value"))
+ conn.execute(t.insert(), dict(data="some %% other value"))
+
+ eq_(
+ conn.scalar(
+ select(t.c.data).where(
+ t.c.data == literal_column("'some % value'")
+ )
+ ),
+ "some % value",
+ )
+
+ eq_(
+ conn.scalar(
+ select(t.c.data).where(
+ t.c.data == literal_column("'some %% other value'")
+ )
+ ),
+ "some %% other value",
+ )
+
+
+class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
+ __backend__ = True
+
+ __requires__ = ("default_schema_name_switch",)
+
+ def test_control_case(self):
+ default_schema_name = config.db.dialect.default_schema_name
+
+ eng = engines.testing_engine()
+ with eng.connect():
+ pass
+
+ eq_(eng.dialect.default_schema_name, default_schema_name)
+
+ def test_wont_work_wo_insert(self):
+ default_schema_name = config.db.dialect.default_schema_name
+
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect")
+ def on_connect(dbapi_connection, connection_record):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, default_schema_name)
+
+ def test_schema_change_on_connect(self):
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect", insert=True)
+ def on_connect(dbapi_connection, connection_record):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, config.test_schema)
+
+ def test_schema_change_works_w_transactions(self):
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect", insert=True)
+ def on_connect(dbapi_connection, *arg):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ trans = conn.begin()
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+ trans.rollback()
+
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, config.test_schema)
+
+
+class FutureWeCanSetDefaultSchemaWEventsTest(
+ fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
+):
+ pass
+
+
+class DifficultParametersTest(fixtures.TestBase):
+ __backend__ = True
+
+ @testing.combinations(
+ ("boring",),
+ ("per cent",),
+ ("per % cent",),
+ ("%percent",),
+ ("par(ens)",),
+ ("percent%(ens)yah",),
+ ("col:ons",),
+ ("more :: %colons%",),
+ ("/slashes/",),
+ ("more/slashes",),
+ ("q?marks",),
+ ("1param",),
+ ("1col:on",),
+ argnames="name",
+ )
+ def test_round_trip(self, name, connection, metadata):
+ t = Table(
+ "t",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column(name, String(50), nullable=False),
+ )
+
+ # table is created
+ t.create(connection)
+
+ # automatic param generated by insert
+ connection.execute(t.insert().values({"id": 1, name: "some name"}))
+
+ # automatic param generated by criteria, plus selecting the column
+ stmt = select(t.c[name]).where(t.c[name] == "some name")
+
+ eq_(connection.scalar(stmt), "some name")
+
+ # use the name in a param explicitly
+ stmt = select(t.c[name]).where(t.c[name] == bindparam(name))
+
+ row = connection.execute(stmt, {name: "some name"}).first()
+
+ # name works as the key from cursor.description
+ eq_(row._mapping[name], "some name")
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
new file mode 100644
index 0000000..3c22f50
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -0,0 +1,367 @@
+from .. import config
+from .. import engines
+from .. import fixtures
+from ..assertions import eq_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import literal
+from ... import literal_column
+from ... import select
+from ... import String
+
+
+class LastrowidTest(fixtures.TablesTest):
+ run_deletes = "each"
+
+ __backend__ = True
+
+ __requires__ = "implements_get_lastrowid", "autoincrement_insert"
+
+ __engine_options__ = {"implicit_returning": False}
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+
+ def _assert_round_trip(self, table, conn):
+ row = conn.execute(table.select()).first()
+ eq_(
+ row,
+ (
+ conn.dialect.default_sequence_base,
+ "some data",
+ ),
+ )
+
+ def test_autoincrement_on_insert(self, connection):
+
+ connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.autoinc_pk, connection)
+
+ def test_last_inserted_id(self, connection):
+
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(r.inserted_primary_key, (pk,))
+
+ @requirements.dbapi_lastrowid
+ def test_native_lastrowid_autoinc(self, connection):
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ lastrowid = r.lastrowid
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(lastrowid, pk)
+
+
+class InsertBehaviorTest(fixtures.TablesTest):
+ run_deletes = "each"
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+ Table(
+ "includes_defaults",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ Column("x", Integer, default=5),
+ Column(
+ "y",
+ Integer,
+ default=literal_column("2", type_=Integer) + literal(2),
+ ),
+ )
+
+ @requirements.autoincrement_insert
+ def test_autoclose_on_insert(self):
+ if requirements.returning.enabled:
+ engine = engines.testing_engine(
+ options={"implicit_returning": False}
+ )
+ else:
+ engine = config.db
+
+ with engine.begin() as conn:
+ r = conn.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ assert r._soft_closed
+ assert not r.closed
+ assert r.is_insert
+
+ # new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment
+ # an insert where the PK was taken from a row that the dialect
+ # selected, as is the case for mssql/pyodbc, will still report
+ # returns_rows as true because there's a cursor description. in that
+ # case, the row had to have been consumed at least.
+ assert not r.returns_rows or r.fetchone() is None
+
+ @requirements.returning
+ def test_autoclose_on_insert_implicit_returning(self, connection):
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ assert r._soft_closed
+ assert not r.closed
+ assert r.is_insert
+
+ # note we are experimenting with having this be True
+ # as of I8091919d45421e3f53029b8660427f844fee0228 .
+ # implicit returning has fetched the row, but it still is a
+ # "returns rows"
+ assert r.returns_rows
+
+ # and we should be able to fetchone() on it, we just get no row
+ eq_(r.fetchone(), None)
+
+ # and the keys, etc.
+ eq_(r.keys(), ["id"])
+
+ # but the dialect took in the row already. not really sure
+ # what the best behavior is.
+
+ @requirements.empty_inserts
+ def test_empty_insert(self, connection):
+ r = connection.execute(self.tables.autoinc_pk.insert())
+ assert r._soft_closed
+ assert not r.closed
+
+ r = connection.execute(
+ self.tables.autoinc_pk.select().where(
+ self.tables.autoinc_pk.c.id != None
+ )
+ )
+ eq_(len(r.all()), 1)
+
+ @requirements.empty_inserts_executemany
+ def test_empty_insert_multiple(self, connection):
+ r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}])
+ assert r._soft_closed
+ assert not r.closed
+
+ r = connection.execute(
+ self.tables.autoinc_pk.select().where(
+ self.tables.autoinc_pk.c.id != None
+ )
+ )
+
+ eq_(len(r.all()), 3)
+
+ @requirements.insert_from_select
+ def test_insert_from_select_autoinc(self, connection):
+ src_table = self.tables.manual_pk
+ dest_table = self.tables.autoinc_pk
+ connection.execute(
+ src_table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ result = connection.execute(
+ dest_table.insert().from_select(
+ ("data",),
+ select(src_table.c.data).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
+ eq_(result.inserted_primary_key, (None,))
+
+ result = connection.execute(
+ select(dest_table.c.data).order_by(dest_table.c.data)
+ )
+ eq_(result.fetchall(), [("data2",), ("data3",)])
+
+ @requirements.insert_from_select
+ def test_insert_from_select_autoinc_no_rows(self, connection):
+ src_table = self.tables.manual_pk
+ dest_table = self.tables.autoinc_pk
+
+ result = connection.execute(
+ dest_table.insert().from_select(
+ ("data",),
+ select(src_table.c.data).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+ eq_(result.inserted_primary_key, (None,))
+
+ result = connection.execute(
+ select(dest_table.c.data).order_by(dest_table.c.data)
+ )
+
+ eq_(result.fetchall(), [])
+
+ @requirements.insert_from_select
+ def test_insert_from_select(self, connection):
+ table = self.tables.manual_pk
+ connection.execute(
+ table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ connection.execute(
+ table.insert()
+ .inline()
+ .from_select(
+ ("id", "data"),
+ select(table.c.id + 5, table.c.data).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
+ eq_(
+ connection.execute(
+ select(table.c.data).order_by(table.c.data)
+ ).fetchall(),
+ [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
+ )
+
+ @requirements.insert_from_select
+ def test_insert_from_select_with_defaults(self, connection):
+ table = self.tables.includes_defaults
+ connection.execute(
+ table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ connection.execute(
+ table.insert()
+ .inline()
+ .from_select(
+ ("id", "data"),
+ select(table.c.id + 5, table.c.data).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
+ eq_(
+ connection.execute(
+ select(table).order_by(table.c.data, table.c.id)
+ ).fetchall(),
+ [
+ (1, "data1", 5, 4),
+ (2, "data2", 5, 4),
+ (7, "data2", 5, 4),
+ (3, "data3", 5, 4),
+ (8, "data3", 5, 4),
+ ],
+ )
+
+
+class ReturningTest(fixtures.TablesTest):
+ run_create_tables = "each"
+ __requires__ = "returning", "autoincrement_insert"
+ __backend__ = True
+
+ __engine_options__ = {"implicit_returning": True}
+
+ def _assert_round_trip(self, table, conn):
+ row = conn.execute(table.select()).first()
+ eq_(
+ row,
+ (
+ conn.dialect.default_sequence_base,
+ "some data",
+ ),
+ )
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+
+ @requirements.fetch_rows_post_commit
+ def test_explicit_returning_pk_autocommit(self, connection):
+ table = self.tables.autoinc_pk
+ r = connection.execute(
+ table.insert().returning(table.c.id), dict(data="some data")
+ )
+ pk = r.first()[0]
+ fetched_pk = connection.scalar(select(table.c.id))
+ eq_(fetched_pk, pk)
+
+ def test_explicit_returning_pk_no_autocommit(self, connection):
+ table = self.tables.autoinc_pk
+ r = connection.execute(
+ table.insert().returning(table.c.id), dict(data="some data")
+ )
+ pk = r.first()[0]
+ fetched_pk = connection.scalar(select(table.c.id))
+ eq_(fetched_pk, pk)
+
+ def test_autoincrement_on_insert_implicit_returning(self, connection):
+
+ connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.autoinc_pk, connection)
+
+ def test_last_inserted_id_implicit_returning(self, connection):
+
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(r.inserted_primary_key, (pk,))
+
+
+__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
new file mode 100644
index 0000000..459a4d8
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -0,0 +1,1738 @@
+import operator
+import re
+
+import sqlalchemy as sa
+from .. import config
+from .. import engines
+from .. import eq_
+from .. import expect_warnings
+from .. import fixtures
+from .. import is_
+from ..provision import get_temp_table_name
+from ..provision import temp_table_keyword_args
+from ..schema import Column
+from ..schema import Table
+from ... import event
+from ... import ForeignKey
+from ... import func
+from ... import Identity
+from ... import inspect
+from ... import Integer
+from ... import MetaData
+from ... import String
+from ... import testing
+from ... import types as sql_types
+from ...schema import DDL
+from ...schema import Index
+from ...sql.elements import quoted_name
+from ...sql.schema import BLANK_SCHEMA
+from ...testing import is_false
+from ...testing import is_true
+
+
+metadata, users = None, None
+
+
+class HasTableTest(fixtures.TablesTest):
+    """Tests the dialect-level ``has_table()`` check, with and without an
+    explicit schema."""
+
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "test_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+        )
+        # a same-shaped table in the alternate schema, when schemas are
+        # supported, to verify schema-qualified lookups don't cross over
+        if testing.requires.schemas.enabled:
+            Table(
+                "test_table_s",
+                metadata,
+                Column("id", Integer, primary_key=True),
+                Column("data", String(50)),
+                schema=config.test_schema,
+            )
+
+    def test_has_table(self):
+        """Default-schema lookup: finds only the default-schema table."""
+        with config.db.begin() as conn:
+            is_true(config.db.dialect.has_table(conn, "test_table"))
+            is_false(config.db.dialect.has_table(conn, "test_table_s"))
+            is_false(config.db.dialect.has_table(conn, "nonexistent_table"))
+
+    @testing.requires.schemas
+    def test_has_table_schema(self):
+        """Schema-qualified lookup: finds only the test-schema table."""
+        with config.db.begin() as conn:
+            is_false(
+                config.db.dialect.has_table(
+                    conn, "test_table", schema=config.test_schema
+                )
+            )
+            is_true(
+                config.db.dialect.has_table(
+                    conn, "test_table_s", schema=config.test_schema
+                )
+            )
+            is_false(
+                config.db.dialect.has_table(
+                    conn, "nonexistent_table", schema=config.test_schema
+                )
+            )
+
+
+class HasIndexTest(fixtures.TablesTest):
+    """Tests the dialect-level ``has_index()`` check, with and without an
+    explicit schema."""
+
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        tt = Table(
+            "test_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+        )
+        Index("my_idx", tt.c.data)
+
+        # same table name in the alternate schema, but a differently named
+        # index, so false positives across schemas can be detected
+        if testing.requires.schemas.enabled:
+            tt = Table(
+                "test_table",
+                metadata,
+                Column("id", Integer, primary_key=True),
+                Column("data", String(50)),
+                schema=config.test_schema,
+            )
+            Index("my_idx_s", tt.c.data)
+
+    def test_has_index(self):
+        """Positive hit plus three distinct negative cases (wrong index,
+        wrong table, nonexistent index)."""
+        with config.db.begin() as conn:
+            assert config.db.dialect.has_index(conn, "test_table", "my_idx")
+            assert not config.db.dialect.has_index(
+                conn, "test_table", "my_idx_s"
+            )
+            assert not config.db.dialect.has_index(
+                conn, "nonexistent_table", "my_idx"
+            )
+            assert not config.db.dialect.has_index(
+                conn, "test_table", "nonexistent_idx"
+            )
+
+    @testing.requires.schemas
+    def test_has_index_schema(self):
+        """Same matrix of cases, schema-qualified."""
+        with config.db.begin() as conn:
+            assert config.db.dialect.has_index(
+                conn, "test_table", "my_idx_s", schema=config.test_schema
+            )
+            assert not config.db.dialect.has_index(
+                conn, "test_table", "my_idx", schema=config.test_schema
+            )
+            assert not config.db.dialect.has_index(
+                conn,
+                "nonexistent_table",
+                "my_idx_s",
+                schema=config.test_schema,
+            )
+            assert not config.db.dialect.has_index(
+                conn,
+                "test_table",
+                "nonexistent_idx_s",
+                schema=config.test_schema,
+            )
+
+
+class QuotedNameArgumentTest(fixtures.TablesTest):
+    """Verifies that all Inspector methods accept table names containing
+    quote characters (single quote, and double quote where supported)
+    without failing or mis-quoting."""
+
+    run_create_tables = "once"
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        # every kind of named construct gets a quote in its name so that
+        # each reflection method is exercised against awkward identifiers
+        Table(
+            "quote ' one",
+            metadata,
+            Column("id", Integer),
+            Column("name", String(50)),
+            Column("data", String(50)),
+            Column("related_id", Integer),
+            sa.PrimaryKeyConstraint("id", name="pk quote ' one"),
+            sa.Index("ix quote ' one", "name"),
+            sa.UniqueConstraint(
+                "data",
+                name="uq quote' one",
+            ),
+            sa.ForeignKeyConstraint(
+                ["id"], ["related.id"], name="fk quote ' one"
+            ),
+            sa.CheckConstraint("name != 'foo'", name="ck quote ' one"),
+            comment=r"""quote ' one comment""",
+            test_needs_fk=True,
+        )
+
+        if testing.requires.symbol_names_w_double_quote.enabled:
+            Table(
+                'quote " two',
+                metadata,
+                Column("id", Integer),
+                Column("name", String(50)),
+                Column("data", String(50)),
+                Column("related_id", Integer),
+                sa.PrimaryKeyConstraint("id", name='pk quote " two'),
+                sa.Index('ix quote " two', "name"),
+                sa.UniqueConstraint(
+                    "data",
+                    name='uq quote" two',
+                ),
+                sa.ForeignKeyConstraint(
+                    ["id"], ["related.id"], name='fk quote " two'
+                ),
+                sa.CheckConstraint("name != 'foo'", name='ck quote " two '),
+                comment=r"""quote " two comment""",
+                test_needs_fk=True,
+            )
+
+        # target of the foreign keys above
+        Table(
+            "related",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("related", Integer),
+            test_needs_fk=True,
+        )
+
+        # matching views ("view quote ' one" etc.) created/dropped via DDL
+        # events, for the view-definition reflection test
+        if testing.requires.view_column_reflection.enabled:
+
+            if testing.requires.symbol_names_w_double_quote.enabled:
+                names = [
+                    "quote ' one",
+                    'quote " two',
+                ]
+            else:
+                names = [
+                    "quote ' one",
+                ]
+            for name in names:
+                query = "CREATE VIEW %s AS SELECT * FROM %s" % (
+                    config.db.dialect.identifier_preparer.quote(
+                        "view %s" % name
+                    ),
+                    config.db.dialect.identifier_preparer.quote(name),
+                )
+
+                event.listen(metadata, "after_create", DDL(query))
+                event.listen(
+                    metadata,
+                    "before_drop",
+                    DDL(
+                        "DROP VIEW %s"
+                        % config.db.dialect.identifier_preparer.quote(
+                            "view %s" % name
+                        )
+                    ),
+                )
+
+    # decorator applied at class-definition time; parametrizes each test
+    # over the quoted names (no @staticmethod needed since it is only
+    # called during class body evaluation)
+    def quote_fixtures(fn):
+        return testing.combinations(
+            ("quote ' one",),
+            ('quote " two', testing.requires.symbol_names_w_double_quote),
+        )(fn)
+
+    @quote_fixtures
+    def test_get_table_options(self, name):
+        insp = inspect(config.db)
+
+        insp.get_table_options(name)
+
+    @quote_fixtures
+    @testing.requires.view_column_reflection
+    def test_get_view_definition(self, name):
+        insp = inspect(config.db)
+        assert insp.get_view_definition("view %s" % name)
+
+    @quote_fixtures
+    def test_get_columns(self, name):
+        insp = inspect(config.db)
+        assert insp.get_columns(name)
+
+    @quote_fixtures
+    def test_get_pk_constraint(self, name):
+        insp = inspect(config.db)
+        assert insp.get_pk_constraint(name)
+
+    @quote_fixtures
+    def test_get_foreign_keys(self, name):
+        insp = inspect(config.db)
+        assert insp.get_foreign_keys(name)
+
+    @quote_fixtures
+    def test_get_indexes(self, name):
+        insp = inspect(config.db)
+        assert insp.get_indexes(name)
+
+    @quote_fixtures
+    @testing.requires.unique_constraint_reflection
+    def test_get_unique_constraints(self, name):
+        insp = inspect(config.db)
+        assert insp.get_unique_constraints(name)
+
+    @quote_fixtures
+    @testing.requires.comment_reflection
+    def test_get_table_comment(self, name):
+        insp = inspect(config.db)
+        assert insp.get_table_comment(name)
+
+    @quote_fixtures
+    @testing.requires.check_constraint_reflection
+    def test_get_check_constraints(self, name):
+        insp = inspect(config.db)
+        assert insp.get_check_constraints(name)
+
+
+class ComponentReflectionTest(fixtures.TablesTest):
+    """Broad exercise of the Inspector reflection API: schema names,
+    tables, views, temp tables, columns, primary/foreign keys, indexes,
+    unique constraints and comments, each gated by backend capability."""
+
+    run_inserts = run_deletes = None
+
+    __backend__ = True
+
+    @classmethod
+    def setup_bind(cls):
+        # a StaticPool keeps one connection alive for the whole class so
+        # connection-scoped temp tables stay visible across tests
+        if config.requirements.independent_connections.enabled:
+            from sqlalchemy import pool
+
+            return engines.testing_engine(
+                options=dict(poolclass=pool.StaticPool, scope="class"),
+            )
+        else:
+            return config.db
+
+    @classmethod
+    def define_tables(cls, metadata):
+        # define the same table set in the default schema and, when
+        # supported, in the alternate test schema
+        cls.define_reflected_tables(metadata, None)
+        if testing.requires.schemas.enabled:
+            cls.define_reflected_tables(metadata, testing.config.test_schema)
+
+    @classmethod
+    def define_reflected_tables(cls, metadata, schema):
+        if schema:
+            schema_prefix = schema + "."
+        else:
+            schema_prefix = ""
+
+        # "users" optionally carries a self-referential FK, which some
+        # backends cannot express
+        if testing.requires.self_referential_foreign_keys.enabled:
+            users = Table(
+                "users",
+                metadata,
+                Column("user_id", sa.INT, primary_key=True),
+                Column("test1", sa.CHAR(5), nullable=False),
+                Column("test2", sa.Float(5), nullable=False),
+                Column(
+                    "parent_user_id",
+                    sa.Integer,
+                    sa.ForeignKey(
+                        "%susers.user_id" % schema_prefix, name="user_id_fk"
+                    ),
+                ),
+                schema=schema,
+                test_needs_fk=True,
+            )
+        else:
+            users = Table(
+                "users",
+                metadata,
+                Column("user_id", sa.INT, primary_key=True),
+                Column("test1", sa.CHAR(5), nullable=False),
+                Column("test2", sa.Float(5), nullable=False),
+                schema=schema,
+                test_needs_fk=True,
+            )
+
+        Table(
+            "dingalings",
+            metadata,
+            Column("dingaling_id", sa.Integer, primary_key=True),
+            Column(
+                "address_id",
+                sa.Integer,
+                sa.ForeignKey("%semail_addresses.address_id" % schema_prefix),
+            ),
+            Column("data", sa.String(30)),
+            schema=schema,
+            test_needs_fk=True,
+        )
+        Table(
+            "email_addresses",
+            metadata,
+            Column("address_id", sa.Integer),
+            Column(
+                "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id)
+            ),
+            Column("email_address", sa.String(20)),
+            sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"),
+            schema=schema,
+            test_needs_fk=True,
+        )
+        # comments include %, quotes and backslashes to exercise escaping
+        Table(
+            "comment_test",
+            metadata,
+            Column("id", sa.Integer, primary_key=True, comment="id comment"),
+            Column("data", sa.String(20), comment="data % comment"),
+            Column(
+                "d2",
+                sa.String(20),
+                comment=r"""Comment types type speedily ' " \ '' Fun!""",
+            ),
+            schema=schema,
+            comment=r"""the test % ' " \ table comment""",
+        )
+
+        # cross-schema FK pairs: local_table lives in the default schema
+        # and points into the test schema; remote_table points back
+        if testing.requires.cross_schema_fk_reflection.enabled:
+            if schema is None:
+                Table(
+                    "local_table",
+                    metadata,
+                    Column("id", sa.Integer, primary_key=True),
+                    Column("data", sa.String(20)),
+                    Column(
+                        "remote_id",
+                        ForeignKey(
+                            "%s.remote_table_2.id" % testing.config.test_schema
+                        ),
+                    ),
+                    test_needs_fk=True,
+                    schema=config.db.dialect.default_schema_name,
+                )
+            else:
+                Table(
+                    "remote_table",
+                    metadata,
+                    Column("id", sa.Integer, primary_key=True),
+                    Column(
+                        "local_id",
+                        ForeignKey(
+                            "%s.local_table.id"
+                            % config.db.dialect.default_schema_name
+                        ),
+                    ),
+                    Column("data", sa.String(20)),
+                    schema=schema,
+                    test_needs_fk=True,
+                )
+                Table(
+                    "remote_table_2",
+                    metadata,
+                    Column("id", sa.Integer, primary_key=True),
+                    Column("data", sa.String(20)),
+                    schema=schema,
+                    test_needs_fk=True,
+                )
+
+        if testing.requires.index_reflection.enabled:
+            cls.define_index(metadata, users)
+
+            if not schema:
+                # test_needs_fk is at the moment to force MySQL InnoDB
+                noncol_idx_test_nopk = Table(
+                    "noncol_idx_test_nopk",
+                    metadata,
+                    Column("q", sa.String(5)),
+                    test_needs_fk=True,
+                )
+
+                noncol_idx_test_pk = Table(
+                    "noncol_idx_test_pk",
+                    metadata,
+                    Column("id", sa.Integer, primary_key=True),
+                    Column("q", sa.String(5)),
+                    test_needs_fk=True,
+                )
+
+                if testing.requires.indexes_with_ascdesc.enabled:
+                    Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc())
+                    Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc())
+
+        if testing.requires.view_column_reflection.enabled:
+            cls.define_views(metadata, schema)
+        if not schema and testing.requires.temp_table_reflection.enabled:
+            cls.define_temp_tables(metadata)
+
+    @classmethod
+    def define_temp_tables(cls, metadata):
+        kw = temp_table_keyword_args(config, config.db)
+        table_name = get_temp_table_name(
+            config, config.db, "user_tmp_%s" % config.ident
+        )
+        user_tmp = Table(
+            table_name,
+            metadata,
+            Column("id", sa.INT, primary_key=True),
+            Column("name", sa.VARCHAR(50)),
+            Column("foo", sa.INT),
+            # disambiguate temp table unique constraint names. this is
+            # pretty arbitrary for a generic dialect however we are doing
+            # it to suit SQL Server which will produce name conflicts for
+            # unique constraints created against temp tables in different
+            # databases.
+            # https://www.arbinada.com/en/node/1645
+            sa.UniqueConstraint("name", name="user_tmp_uq_%s" % config.ident),
+            sa.Index("user_tmp_ix", "foo"),
+            **kw
+        )
+        if (
+            testing.requires.view_reflection.enabled
+            and testing.requires.temporary_views.enabled
+        ):
+            event.listen(
+                user_tmp,
+                "after_create",
+                DDL(
+                    "create temporary view user_tmp_v as "
+                    "select * from user_tmp_%s" % config.ident
+                ),
+            )
+            event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v"))
+
+    @classmethod
+    def define_index(cls, metadata, users):
+        Index("users_t_idx", users.c.test1, users.c.test2)
+        Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1)
+
+    @classmethod
+    def define_views(cls, metadata, schema):
+        # one "<table>_v" view per table, created/dropped via DDL events
+        for table_name in ("users", "email_addresses"):
+            fullname = table_name
+            if schema:
+                fullname = "%s.%s" % (schema, table_name)
+            view_name = fullname + "_v"
+            query = "CREATE VIEW %s AS SELECT * FROM %s" % (
+                view_name,
+                fullname,
+            )
+
+            event.listen(metadata, "after_create", DDL(query))
+            event.listen(
+                metadata, "before_drop", DDL("DROP VIEW %s" % view_name)
+            )
+
+    @testing.requires.schema_reflection
+    def test_get_schema_names(self):
+        insp = inspect(self.bind)
+
+        self.assert_(testing.config.test_schema in insp.get_schema_names())
+
+    @testing.requires.schema_reflection
+    def test_get_schema_names_w_translate_map(self, connection):
+        """test #7300"""
+
+        # schema translate map should not interfere with get_schema_names()
+        connection = connection.execution_options(
+            schema_translate_map={
+                "foo": "bar",
+                BLANK_SCHEMA: testing.config.test_schema,
+            }
+        )
+        insp = inspect(connection)
+
+        self.assert_(testing.config.test_schema in insp.get_schema_names())
+
+    @testing.requires.schema_reflection
+    def test_dialect_initialize(self):
+        # inspect() should trigger dialect initialization
+        engine = engines.testing_engine()
+        inspect(engine)
+        assert hasattr(engine.dialect, "default_schema_name")
+
+    @testing.requires.schema_reflection
+    def test_get_default_schema_name(self):
+        insp = inspect(self.bind)
+        eq_(insp.default_schema_name, self.bind.dialect.default_schema_name)
+
+    @testing.requires.foreign_key_constraint_reflection
+    @testing.combinations(
+        (None, True, False, False),
+        (None, True, False, True, testing.requires.schemas),
+        ("foreign_key", True, False, False),
+        (None, False, True, False),
+        (None, False, True, True, testing.requires.schemas),
+        (None, True, True, False),
+        (None, True, True, True, testing.requires.schemas),
+        argnames="order_by,include_plain,include_views,use_schema",
+    )
+    def test_get_table_names(
+        self, connection, order_by, include_plain, include_views, use_schema
+    ):
+        """get_table_names()/get_view_names(), optionally FK-sorted."""
+
+        if use_schema:
+            schema = config.test_schema
+        else:
+            schema = None
+
+        # tables that may or may not exist depending on capabilities
+        _ignore_tables = [
+            "comment_test",
+            "noncol_idx_test_pk",
+            "noncol_idx_test_nopk",
+            "local_table",
+            "remote_table",
+            "remote_table_2",
+        ]
+
+        insp = inspect(connection)
+
+        if include_views:
+            table_names = insp.get_view_names(schema)
+            table_names.sort()
+            answer = ["email_addresses_v", "users_v"]
+            eq_(sorted(table_names), answer)
+
+        if include_plain:
+            if order_by:
+                tables = [
+                    rec[0]
+                    for rec in insp.get_sorted_table_and_fkc_names(schema)
+                    if rec[0]
+                ]
+            else:
+                tables = insp.get_table_names(schema)
+            table_names = [t for t in tables if t not in _ignore_tables]
+
+            if order_by == "foreign_key":
+                # dependency order: referenced tables come first
+                answer = ["users", "email_addresses", "dingalings"]
+                eq_(table_names, answer)
+            else:
+                answer = ["dingalings", "email_addresses", "users"]
+                eq_(sorted(table_names), answer)
+
+    @testing.requires.temp_table_names
+    def test_get_temp_table_names(self):
+        insp = inspect(self.bind)
+        temp_table_names = insp.get_temp_table_names()
+        eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident])
+
+    @testing.requires.view_reflection
+    @testing.requires.temp_table_names
+    @testing.requires.temporary_views
+    def test_get_temp_view_names(self):
+        insp = inspect(self.bind)
+        temp_table_names = insp.get_temp_view_names()
+        eq_(sorted(temp_table_names), ["user_tmp_v"])
+
+    @testing.requires.comment_reflection
+    def test_get_comments(self):
+        self._test_get_comments()
+
+    @testing.requires.comment_reflection
+    @testing.requires.schemas
+    def test_get_comments_with_schema(self):
+        self._test_get_comments(testing.config.test_schema)
+
+    def _test_get_comments(self, schema=None):
+        # table and column comments round-trip, including the escaped
+        # characters embedded in comment_test's definitions
+        insp = inspect(self.bind)
+
+        eq_(
+            insp.get_table_comment("comment_test", schema=schema),
+            {"text": r"""the test % ' " \ table comment"""},
+        )
+
+        eq_(insp.get_table_comment("users", schema=schema), {"text": None})
+
+        eq_(
+            [
+                {"name": rec["name"], "comment": rec["comment"]}
+                for rec in insp.get_columns("comment_test", schema=schema)
+            ],
+            [
+                {"comment": "id comment", "name": "id"},
+                {"comment": "data % comment", "name": "data"},
+                {
+                    "comment": (
+                        r"""Comment types type speedily ' " \ '' Fun!"""
+                    ),
+                    "name": "d2",
+                },
+            ],
+        )
+
+    @testing.combinations(
+        (False, False),
+        (False, True, testing.requires.schemas),
+        (True, False, testing.requires.view_reflection),
+        (
+            True,
+            True,
+            testing.requires.schemas + testing.requires.view_reflection,
+        ),
+        argnames="use_views,use_schema",
+    )
+    def test_get_columns(self, connection, use_views, use_schema):
+        """get_columns() reports names in order with compatible types."""
+
+        if use_schema:
+            schema = config.test_schema
+        else:
+            schema = None
+
+        users, addresses = (self.tables.users, self.tables.email_addresses)
+        if use_views:
+            table_names = ["users_v", "email_addresses_v"]
+        else:
+            table_names = ["users", "email_addresses"]
+
+        insp = inspect(connection)
+        for table_name, table in zip(table_names, (users, addresses)):
+            schema_name = schema
+            cols = insp.get_columns(table_name, schema=schema_name)
+            self.assert_(len(cols) > 0, len(cols))
+
+            # should be in order
+
+            for i, col in enumerate(table.columns):
+                eq_(col.name, cols[i]["name"])
+                ctype = cols[i]["type"].__class__
+                ctype_def = col.type
+                if isinstance(ctype_def, sa.types.TypeEngine):
+                    ctype_def = ctype_def.__class__
+
+                # Oracle returns Date for DateTime.
+
+                if testing.against("oracle") and ctype_def in (
+                    sql_types.Date,
+                    sql_types.DateTime,
+                ):
+                    ctype_def = sql_types.Date
+
+                # assert that the desired type and return type share
+                # a base within one of the generic types.
+
+                self.assert_(
+                    len(
+                        set(ctype.__mro__)
+                        .intersection(ctype_def.__mro__)
+                        .intersection(
+                            [
+                                sql_types.Integer,
+                                sql_types.Numeric,
+                                sql_types.DateTime,
+                                sql_types.Date,
+                                sql_types.Time,
+                                sql_types.String,
+                                sql_types._Binary,
+                            ]
+                        )
+                    )
+                    > 0,
+                    "%s(%s), %s(%s)"
+                    % (col.name, col.type, cols[i]["name"], ctype),
+                )
+
+                if not col.primary_key:
+                    assert cols[i]["default"] is None
+
+    @testing.requires.temp_table_reflection
+    def test_get_temp_table_columns(self):
+        table_name = get_temp_table_name(
+            config, self.bind, "user_tmp_%s" % config.ident
+        )
+        user_tmp = self.tables[table_name]
+        insp = inspect(self.bind)
+        cols = insp.get_columns(table_name)
+        self.assert_(len(cols) > 0, len(cols))
+
+        for i, col in enumerate(user_tmp.columns):
+            eq_(col.name, cols[i]["name"])
+
+    @testing.requires.temp_table_reflection
+    @testing.requires.view_column_reflection
+    @testing.requires.temporary_views
+    def test_get_temp_view_columns(self):
+        insp = inspect(self.bind)
+        cols = insp.get_columns("user_tmp_v")
+        eq_([col["name"] for col in cols], ["id", "name", "foo"])
+
+    @testing.combinations(
+        (False,), (True, testing.requires.schemas), argnames="use_schema"
+    )
+    @testing.requires.primary_key_constraint_reflection
+    def test_get_pk_constraint(self, connection, use_schema):
+        if use_schema:
+            schema = testing.config.test_schema
+        else:
+            schema = None
+
+        users, addresses = self.tables.users, self.tables.email_addresses
+        insp = inspect(connection)
+
+        users_cons = insp.get_pk_constraint(users.name, schema=schema)
+        users_pkeys = users_cons["constrained_columns"]
+        eq_(users_pkeys, ["user_id"])
+
+        addr_cons = insp.get_pk_constraint(addresses.name, schema=schema)
+        addr_pkeys = addr_cons["constrained_columns"]
+        eq_(addr_pkeys, ["address_id"])
+
+        # pk name only checked where the backend preserves it
+        with testing.requires.reflects_pk_names.fail_if():
+            eq_(addr_cons["name"], "email_ad_pk")
+
+    @testing.combinations(
+        (False,), (True, testing.requires.schemas), argnames="use_schema"
+    )
+    @testing.requires.foreign_key_constraint_reflection
+    def test_get_foreign_keys(self, connection, use_schema):
+        if use_schema:
+            schema = config.test_schema
+        else:
+            schema = None
+
+        users, addresses = (self.tables.users, self.tables.email_addresses)
+        insp = inspect(connection)
+        expected_schema = schema
+        # users
+
+        if testing.requires.self_referential_foreign_keys.enabled:
+            users_fkeys = insp.get_foreign_keys(users.name, schema=schema)
+            fkey1 = users_fkeys[0]
+
+            with testing.requires.named_constraints.fail_if():
+                eq_(fkey1["name"], "user_id_fk")
+
+            eq_(fkey1["referred_schema"], expected_schema)
+            eq_(fkey1["referred_table"], users.name)
+            eq_(fkey1["referred_columns"], ["user_id"])
+            if testing.requires.self_referential_foreign_keys.enabled:
+                eq_(fkey1["constrained_columns"], ["parent_user_id"])
+
+        # addresses
+        addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema)
+        fkey1 = addr_fkeys[0]
+
+        with testing.requires.implicitly_named_constraints.fail_if():
+            self.assert_(fkey1["name"] is not None)
+
+        eq_(fkey1["referred_schema"], expected_schema)
+        eq_(fkey1["referred_table"], users.name)
+        eq_(fkey1["referred_columns"], ["user_id"])
+        eq_(fkey1["constrained_columns"], ["remote_user_id"])
+
+    @testing.requires.cross_schema_fk_reflection
+    @testing.requires.schemas
+    def test_get_inter_schema_foreign_keys(self):
+        """FKs crossing between default and test schema reflect both ways."""
+        local_table, remote_table, remote_table_2 = self.tables(
+            "%s.local_table" % self.bind.dialect.default_schema_name,
+            "%s.remote_table" % testing.config.test_schema,
+            "%s.remote_table_2" % testing.config.test_schema,
+        )
+
+        insp = inspect(self.bind)
+
+        local_fkeys = insp.get_foreign_keys(local_table.name)
+        eq_(len(local_fkeys), 1)
+
+        fkey1 = local_fkeys[0]
+        eq_(fkey1["referred_schema"], testing.config.test_schema)
+        eq_(fkey1["referred_table"], remote_table_2.name)
+        eq_(fkey1["referred_columns"], ["id"])
+        eq_(fkey1["constrained_columns"], ["remote_id"])
+
+        remote_fkeys = insp.get_foreign_keys(
+            remote_table.name, schema=testing.config.test_schema
+        )
+        eq_(len(remote_fkeys), 1)
+
+        fkey2 = remote_fkeys[0]
+
+        # backends differ on whether the default schema is reported
+        # explicitly or as None
+        assert fkey2["referred_schema"] in (
+            None,
+            self.bind.dialect.default_schema_name,
+        )
+        eq_(fkey2["referred_table"], local_table.name)
+        eq_(fkey2["referred_columns"], ["id"])
+        eq_(fkey2["constrained_columns"], ["local_id"])
+
+    def _assert_insp_indexes(self, indexes, expected_indexes):
+        # each expected index must be present; only the keys listed in the
+        # expectation are compared, extra reflected keys are ignored
+        index_names = [d["name"] for d in indexes]
+        for e_index in expected_indexes:
+            assert e_index["name"] in index_names
+            index = indexes[index_names.index(e_index["name"])]
+            for key in e_index:
+                eq_(e_index[key], index[key])
+
+    @testing.combinations(
+        (False,), (True, testing.requires.schemas), argnames="use_schema"
+    )
+    def test_get_indexes(self, connection, use_schema):
+
+        if use_schema:
+            schema = config.test_schema
+        else:
+            schema = None
+
+        # The database may decide to create indexes for foreign keys, etc.
+        # so there may be more indexes than expected.
+        insp = inspect(self.bind)
+        indexes = insp.get_indexes("users", schema=schema)
+        expected_indexes = [
+            {
+                "unique": False,
+                "column_names": ["test1", "test2"],
+                "name": "users_t_idx",
+            },
+            {
+                "unique": False,
+                "column_names": ["user_id", "test2", "test1"],
+                "name": "users_all_idx",
+            },
+        ]
+        self._assert_insp_indexes(indexes, expected_indexes)
+
+    @testing.combinations(
+        ("noncol_idx_test_nopk", "noncol_idx_nopk"),
+        ("noncol_idx_test_pk", "noncol_idx_pk"),
+        argnames="tname,ixname",
+    )
+    @testing.requires.index_reflection
+    @testing.requires.indexes_with_ascdesc
+    def test_get_noncol_index(self, connection, tname, ixname):
+        insp = inspect(connection)
+        indexes = insp.get_indexes(tname)
+
+        # reflecting an index that has "x DESC" in it as the column.
+        # the DB may or may not give us "x", but make sure we get the index
+        # back, it has a name, it's connected to the table.
+        expected_indexes = [{"unique": False, "name": ixname}]
+        self._assert_insp_indexes(indexes, expected_indexes)
+
+        t = Table(tname, MetaData(), autoload_with=connection)
+        eq_(len(t.indexes), 1)
+        is_(list(t.indexes)[0].table, t)
+        eq_(list(t.indexes)[0].name, ixname)
+
+    @testing.requires.temp_table_reflection
+    @testing.requires.unique_constraint_reflection
+    def test_get_temp_table_unique_constraints(self):
+        insp = inspect(self.bind)
+        reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident)
+        for refl in reflected:
+            # Different dialects handle duplicate index and constraints
+            # differently, so ignore this flag
+            refl.pop("duplicates_index", None)
+        eq_(
+            reflected,
+            [
+                {
+                    "column_names": ["name"],
+                    "name": "user_tmp_uq_%s" % config.ident,
+                }
+            ],
+        )
+
+    @testing.requires.temp_table_reflect_indexes
+    def test_get_temp_table_indexes(self):
+        insp = inspect(self.bind)
+        table_name = get_temp_table_name(
+            config, config.db, "user_tmp_%s" % config.ident
+        )
+        indexes = insp.get_indexes(table_name)
+        for ind in indexes:
+            ind.pop("dialect_options", None)
+        expected = [
+            {"unique": False, "column_names": ["foo"], "name": "user_tmp_ix"}
+        ]
+        if testing.requires.index_reflects_included_columns.enabled:
+            expected[0]["include_columns"] = []
+        eq_(
+            [idx for idx in indexes if idx["name"] == "user_tmp_ix"],
+            expected,
+        )
+
+    @testing.combinations(
+        (True, testing.requires.schemas), (False,), argnames="use_schema"
+    )
+    @testing.requires.unique_constraint_reflection
+    def test_get_unique_constraints(self, metadata, connection, use_schema):
+        # SQLite dialect needs to parse the names of the constraints
+        # separately from what it gets from PRAGMA index_list(), and
+        # then matches them up. so same set of column_names in two
+        # constraints will confuse it. Perhaps we should no longer
+        # bother with index_list() here since we have the whole
+        # CREATE TABLE?
+
+        if use_schema:
+            schema = config.test_schema
+        else:
+            schema = None
+        uniques = sorted(
+            [
+                {"name": "unique_a", "column_names": ["a"]},
+                {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]},
+                {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]},
+                {"name": "unique_asc_key", "column_names": ["asc", "key"]},
+                {"name": "i.have.dots", "column_names": ["b"]},
+                {"name": "i have spaces", "column_names": ["c"]},
+            ],
+            key=operator.itemgetter("name"),
+        )
+        table = Table(
+            "testtbl",
+            metadata,
+            Column("a", sa.String(20)),
+            Column("b", sa.String(30)),
+            Column("c", sa.Integer),
+            # reserved identifiers
+            Column("asc", sa.String(30)),
+            Column("key", sa.String(30)),
+            schema=schema,
+        )
+        for uc in uniques:
+            table.append_constraint(
+                sa.UniqueConstraint(*uc["column_names"], name=uc["name"])
+            )
+        table.create(connection)
+
+        inspector = inspect(connection)
+        reflected = sorted(
+            inspector.get_unique_constraints("testtbl", schema=schema),
+            key=operator.itemgetter("name"),
+        )
+
+        names_that_duplicate_index = set()
+
+        for orig, refl in zip(uniques, reflected):
+            # Different dialects handle duplicate index and constraints
+            # differently, so ignore this flag
+            dupe = refl.pop("duplicates_index", None)
+            if dupe:
+                names_that_duplicate_index.add(dupe)
+            eq_(orig, refl)
+
+        reflected_metadata = MetaData()
+        reflected = Table(
+            "testtbl",
+            reflected_metadata,
+            autoload_with=connection,
+            schema=schema,
+        )
+
+        # test "deduplicates for index" logic. MySQL and Oracle
+        # "unique constraints" are actually unique indexes (with possible
+        # exception of a unique that is a dupe of another one in the case
+        # of Oracle). make sure # they aren't duplicated.
+        idx_names = set([idx.name for idx in reflected.indexes])
+        uq_names = set(
+            [
+                uq.name
+                for uq in reflected.constraints
+                if isinstance(uq, sa.UniqueConstraint)
+            ]
+        ).difference(["unique_c_a_b"])
+
+        assert not idx_names.intersection(uq_names)
+        if names_that_duplicate_index:
+            eq_(names_that_duplicate_index, idx_names)
+            eq_(uq_names, set())
+
+    @testing.requires.view_reflection
+    @testing.combinations(
+        (False,), (True, testing.requires.schemas), argnames="use_schema"
+    )
+    def test_get_view_definition(self, connection, use_schema):
+        if use_schema:
+            schema = config.test_schema
+        else:
+            schema = None
+        view_name1 = "users_v"
+        view_name2 = "email_addresses_v"
+        insp = inspect(connection)
+        v1 = insp.get_view_definition(view_name1, schema=schema)
+        self.assert_(v1)
+        v2 = insp.get_view_definition(view_name2, schema=schema)
+        self.assert_(v2)
+
+    # why is this here if it's PG specific ?
+    @testing.combinations(
+        ("users", False),
+        ("users", True, testing.requires.schemas),
+        argnames="table_name,use_schema",
+    )
+    @testing.only_on("postgresql", "PG specific feature")
+    def test_get_table_oid(self, connection, table_name, use_schema):
+        if use_schema:
+            schema = config.test_schema
+        else:
+            schema = None
+        insp = inspect(connection)
+        oid = insp.get_table_oid(table_name, schema)
+        self.assert_(isinstance(oid, int))
+
+    @testing.requires.table_reflection
+    def test_autoincrement_col(self):
+        """test that 'autoincrement' is reflected according to sqla's policy.
+
+        Don't mark this test as unsupported for any backend !
+
+        (technically it fails with MySQL InnoDB since "id" comes before "id2")
+
+        A backend is better off not returning "autoincrement" at all,
+        instead of potentially returning "False" for an auto-incrementing
+        primary key column.
+
+        """
+
+        insp = inspect(self.bind)
+
+        for tname, cname in [
+            ("users", "user_id"),
+            ("email_addresses", "address_id"),
+            ("dingalings", "dingaling_id"),
+        ]:
+            cols = insp.get_columns(tname)
+            id_ = {c["name"]: c for c in cols}[cname]
+            assert id_.get("autoincrement", True)
+
+
+class TableNoColumnsTest(fixtures.TestBase):
+    """Reflection of zero-column tables and views, for the backends that
+    allow creating them at all."""
+
+    __requires__ = ("reflect_tables_no_columns",)
+    __backend__ = True
+
+    @testing.fixture
+    def table_no_columns(self, connection, metadata):
+        # a table with no columns whatsoever
+        Table("empty", metadata)
+        metadata.create_all(connection)
+
+    @testing.fixture
+    def view_no_columns(self, connection, metadata):
+        Table("empty", metadata)
+        metadata.create_all(connection)
+
+        # NOTE(review): this second Table("empty", metadata) call appears
+        # redundant with the one above -- presumably harmless; confirm
+        Table("empty", metadata)
+        event.listen(
+            metadata,
+            "after_create",
+            DDL("CREATE VIEW empty_v AS SELECT * FROM empty"),
+        )
+
+        # for transactional DDL the transaction is rolled back before this
+        # drop statement is invoked
+        event.listen(
+            metadata, "before_drop", DDL("DROP VIEW IF EXISTS empty_v")
+        )
+        metadata.create_all(connection)
+
+    @testing.requires.reflect_tables_no_columns
+    def test_reflect_table_no_columns(self, connection, table_no_columns):
+        t2 = Table("empty", MetaData(), autoload_with=connection)
+        eq_(list(t2.c), [])
+
+    @testing.requires.reflect_tables_no_columns
+    def test_get_columns_table_no_columns(self, connection, table_no_columns):
+        eq_(inspect(connection).get_columns("empty"), [])
+
+    @testing.requires.reflect_tables_no_columns
+    def test_reflect_incl_table_no_columns(self, connection, table_no_columns):
+        m = MetaData()
+        m.reflect(connection)
+        assert set(m.tables).intersection(["empty"])
+
+    @testing.requires.views
+    @testing.requires.reflect_tables_no_columns
+    def test_reflect_view_no_columns(self, connection, view_no_columns):
+        t2 = Table("empty_v", MetaData(), autoload_with=connection)
+        eq_(list(t2.c), [])
+
+    @testing.requires.views
+    @testing.requires.reflect_tables_no_columns
+    def test_get_columns_view_no_columns(self, connection, view_no_columns):
+        eq_(inspect(connection).get_columns("empty_v"), [])
+
+
+class ComponentReflectionTestExtra(fixtures.TestBase):
+    """Additional reflection cases that build their own tables per test
+    (check constraints, expression/covering indexes, type round trips).
+    The class continues beyond this hunk's visible region."""
+
+    __backend__ = True
+
+    @testing.combinations(
+        (True, testing.requires.schemas), (False,), argnames="use_schema"
+    )
+    @testing.requires.check_constraint_reflection
+    def test_get_check_constraints(self, metadata, connection, use_schema):
+        """CHECK constraints reflect with name and (normalized) SQL text."""
+        if use_schema:
+            schema = config.test_schema
+        else:
+            schema = None
+
+        Table(
+            "sa_cc",
+            metadata,
+            Column("a", Integer()),
+            sa.CheckConstraint("a > 1 AND a < 5", name="cc1"),
+            sa.CheckConstraint(
+                "a = 1 OR (a > 2 AND a < 5)", name="UsesCasing"
+            ),
+            schema=schema,
+        )
+
+        metadata.create_all(connection)
+
+        inspector = inspect(connection)
+        reflected = sorted(
+            inspector.get_check_constraints("sa_cc", schema=schema),
+            key=operator.itemgetter("name"),
+        )
+
+        # trying to minimize effect of quoting, parenthesis, etc.
+        # may need to add more to this as new dialects get CHECK
+        # constraint reflection support
+        def normalize(sqltext):
+            return " ".join(
+                re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I)
+            )
+
+        reflected = [
+            {"name": item["name"], "sqltext": normalize(item["sqltext"])}
+            for item in reflected
+        ]
+        eq_(
+            reflected,
+            [
+                {"name": "UsesCasing", "sqltext": "a = 1 or a > 2 and a < 5"},
+                {"name": "cc1", "sqltext": "a > 1 and a < 5"},
+            ],
+        )
+
+ @testing.requires.indexes_with_expressions
+ def test_reflect_expression_based_indexes(self, metadata, connection):
+ t = Table(
+ "t",
+ metadata,
+ Column("x", String(30)),
+ Column("y", String(30)),
+ )
+
+ Index("t_idx", func.lower(t.c.x), func.lower(t.c.y))
+
+ Index("t_idx_2", t.c.x)
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+ expected = [
+ {"name": "t_idx_2", "column_names": ["x"], "unique": False}
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ expected[0]["dialect_options"] = {
+ "%s_include" % connection.engine.name: []
+ }
+
+ with expect_warnings(
+ "Skipped unsupported reflection of expression-based index t_idx"
+ ):
+ eq_(
+ insp.get_indexes("t"),
+ expected,
+ )
+
+ @testing.requires.index_reflects_included_columns
+ def test_reflect_covering_index(self, metadata, connection):
+ t = Table(
+ "t",
+ metadata,
+ Column("x", String(30)),
+ Column("y", String(30)),
+ )
+ idx = Index("t_idx", t.c.x)
+ idx.dialect_options[connection.engine.name]["include"] = ["y"]
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+ eq_(
+ insp.get_indexes("t"),
+ [
+ {
+ "name": "t_idx",
+ "column_names": ["x"],
+ "include_columns": ["y"],
+ "unique": False,
+ "dialect_options": {
+ "%s_include" % connection.engine.name: ["y"]
+ },
+ }
+ ],
+ )
+
+ t2 = Table("t", MetaData(), autoload_with=connection)
+ eq_(
+ list(t2.indexes)[0].dialect_options[connection.engine.name][
+ "include"
+ ],
+ ["y"],
+ )
+
+ def _type_round_trip(self, connection, metadata, *types):
+ t = Table(
+ "t",
+ metadata,
+ *[Column("t%d" % i, type_) for i, type_ in enumerate(types)]
+ )
+ t.create(connection)
+
+ return [c["type"] for c in inspect(connection).get_columns("t")]
+
+ @testing.requires.table_reflection
+ def test_numeric_reflection(self, connection, metadata):
+ for typ in self._type_round_trip(
+ connection, metadata, sql_types.Numeric(18, 5)
+ ):
+ assert isinstance(typ, sql_types.Numeric)
+ eq_(typ.precision, 18)
+ eq_(typ.scale, 5)
+
+ @testing.requires.table_reflection
+ def test_varchar_reflection(self, connection, metadata):
+ typ = self._type_round_trip(
+ connection, metadata, sql_types.String(52)
+ )[0]
+ assert isinstance(typ, sql_types.String)
+ eq_(typ.length, 52)
+
+ @testing.requires.table_reflection
+ def test_nullable_reflection(self, connection, metadata):
+ t = Table(
+ "t",
+ metadata,
+ Column("a", Integer, nullable=True),
+ Column("b", Integer, nullable=False),
+ )
+ t.create(connection)
+ eq_(
+ dict(
+ (col["name"], col["nullable"])
+ for col in inspect(connection).get_columns("t")
+ ),
+ {"a": True, "b": False},
+ )
+
+ @testing.combinations(
+ (
+ None,
+ "CASCADE",
+ None,
+ testing.requires.foreign_key_constraint_option_reflection_ondelete,
+ ),
+ (
+ None,
+ None,
+ "SET NULL",
+ testing.requires.foreign_key_constraint_option_reflection_onupdate,
+ ),
+ (
+ {},
+ None,
+ "NO ACTION",
+ testing.requires.foreign_key_constraint_option_reflection_onupdate,
+ ),
+ (
+ {},
+ "NO ACTION",
+ None,
+ testing.requires.fk_constraint_option_reflection_ondelete_noaction,
+ ),
+ (
+ None,
+ None,
+ "RESTRICT",
+ testing.requires.fk_constraint_option_reflection_onupdate_restrict,
+ ),
+ (
+ None,
+ "RESTRICT",
+ None,
+ testing.requires.fk_constraint_option_reflection_ondelete_restrict,
+ ),
+ argnames="expected,ondelete,onupdate",
+ )
+ def test_get_foreign_key_options(
+ self, connection, metadata, expected, ondelete, onupdate
+ ):
+ options = {}
+ if ondelete:
+ options["ondelete"] = ondelete
+ if onupdate:
+ options["onupdate"] = onupdate
+
+ if expected is None:
+ expected = options
+
+ Table(
+ "x",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")),
+ Column("test", String(10)),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "user",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("tid", Integer),
+ sa.ForeignKeyConstraint(
+ ["tid"], ["table.id"], name="myfk", **options
+ ),
+ test_needs_fk=True,
+ )
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+ # test 'options' is always present for a backend
+ # that can reflect these, since alembic looks for this
+ opts = insp.get_foreign_keys("table")[0]["options"]
+
+ eq_(dict((k, opts[k]) for k in opts if opts[k]), {})
+
+ opts = insp.get_foreign_keys("user")[0]["options"]
+ eq_(opts, expected)
+ # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected)
+
+
+class NormalizedNameTest(fixtures.TablesTest):
+ __requires__ = ("denormalized_names",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ quoted_name("t1", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
+ )
+ Table(
+ quoted_name("t2", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("t1id", ForeignKey("t1.id")),
+ )
+
+ def test_reflect_lowercase_forced_tables(self):
+
+ m2 = MetaData()
+ t2_ref = Table(
+ quoted_name("t2", quote=True), m2, autoload_with=config.db
+ )
+ t1_ref = m2.tables["t1"]
+ assert t2_ref.c.t1id.references(t1_ref.c.id)
+
+ m3 = MetaData()
+ m3.reflect(
+ config.db, only=lambda name, m: name.lower() in ("t1", "t2")
+ )
+ assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id)
+
+ def test_get_table_names(self):
+ tablenames = [
+ t
+ for t in inspect(config.db).get_table_names()
+ if t.lower() in ("t1", "t2")
+ ]
+
+ eq_(tablenames[0].upper(), tablenames[0].lower())
+ eq_(tablenames[1].upper(), tablenames[1].lower())
+
+
+class ComputedReflectionTest(fixtures.ComputedReflectionFixtureTest):
+    """Reflection of GENERATED / computed columns, using the tables
+    created by ComputedReflectionFixtureTest.
+    """
+
+    def test_computed_col_default_not_set(self):
+        # a computed column must not be reported as having a plain
+        # server default; only the genuinely defaulted column has one
+        insp = inspect(config.db)
+
+        cols = insp.get_columns("computed_default_table")
+        col_data = {c["name"]: c for c in cols}
+        is_true("42" in col_data["with_default"]["default"])
+        is_(col_data["normal"]["default"], None)
+        is_(col_data["computed_col"]["default"], None)
+
+    def test_get_column_returns_computed(self):
+        insp = inspect(config.db)
+
+        cols = insp.get_columns("computed_default_table")
+        data = {c["name"]: c for c in cols}
+        # only the computed column carries the "computed" key
+        for key in ("id", "normal", "with_default"):
+            is_true("computed" not in data[key])
+        compData = data["computed_col"]
+        is_true("computed" in compData)
+        is_true("sqltext" in compData["computed"])
+        # normalize() comes from the fixture base class; it strips
+        # quoting/whitespace from the reflected expression (confirm there)
+        eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42")
+        # "persisted" is present only on backends that reflect it
+        eq_(
+            "persisted" in compData["computed"],
+            testing.requires.computed_columns_reflect_persisted.enabled,
+        )
+        if testing.requires.computed_columns_reflect_persisted.enabled:
+            eq_(
+                compData["computed"]["persisted"],
+                testing.requires.computed_columns_default_persisted.enabled,
+            )
+
+    def check_column(self, data, column, sqltext, persisted):
+        # Helper: assert a reflected column is computed with the given
+        # normalized expression and (where reflectable) persistence flag.
+        is_true("computed" in data[column])
+        compData = data[column]["computed"]
+        eq_(self.normalize(compData["sqltext"]), sqltext)
+        if testing.requires.computed_columns_reflect_persisted.enabled:
+            is_true("persisted" in compData)
+            is_(compData["persisted"], persisted)
+
+    def test_get_column_returns_persisted(self):
+        insp = inspect(config.db)
+
+        cols = insp.get_columns("computed_column_table")
+        data = {c["name"]: c for c in cols}
+
+        self.check_column(
+            data,
+            "computed_no_flag",
+            "normal+42",
+            testing.requires.computed_columns_default_persisted.enabled,
+        )
+        if testing.requires.computed_columns_virtual.enabled:
+            self.check_column(
+                data,
+                "computed_virtual",
+                "normal+2",
+                False,
+            )
+        if testing.requires.computed_columns_stored.enabled:
+            self.check_column(
+                data,
+                "computed_stored",
+                "normal-42",
+                True,
+            )
+
+    @testing.requires.schemas
+    def test_get_column_returns_persisted_with_schema(self):
+        # same checks against the schema-qualified table; its fixture
+        # uses different operators (/, *) so a mixup with the default
+        # schema's table would be caught
+        insp = inspect(config.db)
+
+        cols = insp.get_columns(
+            "computed_column_table", schema=config.test_schema
+        )
+        data = {c["name"]: c for c in cols}
+
+        self.check_column(
+            data,
+            "computed_no_flag",
+            "normal/42",
+            testing.requires.computed_columns_default_persisted.enabled,
+        )
+        if testing.requires.computed_columns_virtual.enabled:
+            self.check_column(
+                data,
+                "computed_virtual",
+                "normal/2",
+                False,
+            )
+        if testing.requires.computed_columns_stored.enabled:
+            self.check_column(
+                data,
+                "computed_stored",
+                "normal*42",
+                True,
+            )
+
+
+class IdentityReflectionTest(fixtures.TablesTest):
+    """Reflection of IDENTITY column attributes (start, increment,
+    min/max, cycle, cache).
+    """
+
+    run_inserts = run_deletes = None
+
+    __backend__ = True
+    __requires__ = ("identity_columns", "table_reflection")
+
+    @classmethod
+    def define_tables(cls, metadata):
+        # t1: identity with all defaults; t2: every option set explicitly
+        Table(
+            "t1",
+            metadata,
+            Column("normal", Integer),
+            Column("id1", Integer, Identity()),
+        )
+        Table(
+            "t2",
+            metadata,
+            Column(
+                "id2",
+                Integer,
+                Identity(
+                    always=True,
+                    start=2,
+                    increment=3,
+                    minvalue=-2,
+                    maxvalue=42,
+                    cycle=True,
+                    cache=4,
+                ),
+            ),
+        )
+        if testing.requires.schemas.enabled:
+            # same table name in the test schema, with different options,
+            # to verify schema-qualified reflection
+            Table(
+                "t1",
+                metadata,
+                Column("normal", Integer),
+                Column("id1", Integer, Identity(always=True, start=20)),
+                schema=config.test_schema,
+            )
+
+    def check(self, value, exp, approx):
+        # Compare a reflected identity dict against the expected one.
+        # With approx=True, minvalue may reflect lower and maxvalue/cache
+        # higher than requested (backends may report their own bounds);
+        # everything else must match exactly.
+        if testing.requires.identity_columns_standard.enabled:
+            common_keys = (
+                "always",
+                "start",
+                "increment",
+                "minvalue",
+                "maxvalue",
+                "cycle",
+                "cache",
+            )
+            # drop dialect-specific extras before comparing
+            for k in list(value):
+                if k not in common_keys:
+                    value.pop(k)
+            if approx:
+                eq_(len(value), len(exp))
+                for k in value:
+                    if k == "minvalue":
+                        is_true(value[k] <= exp[k])
+                    elif k in {"maxvalue", "cache"}:
+                        is_true(value[k] >= exp[k])
+                    else:
+                        eq_(value[k], exp[k], k)
+            else:
+                eq_(value, exp)
+        else:
+            # non-standard identity support: only start/increment reflect
+            eq_(value["start"], exp["start"])
+            eq_(value["increment"], exp["increment"])
+
+    def test_reflect_identity(self):
+        insp = inspect(config.db)
+
+        cols = insp.get_columns("t1") + insp.get_columns("t2")
+        for col in cols:
+            if col["name"] == "normal":
+                is_false("identity" in col)
+            elif col["name"] == "id1":
+                # default Identity(): compare approximately since the
+                # backend fills in its own defaults
+                is_true(col["autoincrement"] in (True, "auto"))
+                eq_(col["default"], None)
+                is_true("identity" in col)
+                self.check(
+                    col["identity"],
+                    dict(
+                        always=False,
+                        start=1,
+                        increment=1,
+                        minvalue=1,
+                        maxvalue=2147483647,
+                        cycle=False,
+                        cache=1,
+                    ),
+                    approx=True,
+                )
+            elif col["name"] == "id2":
+                # fully-specified Identity(): exact match expected
+                is_true(col["autoincrement"] in (True, "auto"))
+                eq_(col["default"], None)
+                is_true("identity" in col)
+                self.check(
+                    col["identity"],
+                    dict(
+                        always=True,
+                        start=2,
+                        increment=3,
+                        minvalue=-2,
+                        maxvalue=42,
+                        cycle=True,
+                        cache=4,
+                    ),
+                    approx=False,
+                )
+
+    @testing.requires.schemas
+    def test_reflect_identity_schema(self):
+        insp = inspect(config.db)
+
+        cols = insp.get_columns("t1", schema=config.test_schema)
+        for col in cols:
+            if col["name"] == "normal":
+                is_false("identity" in col)
+            elif col["name"] == "id1":
+                is_true(col["autoincrement"] in (True, "auto"))
+                eq_(col["default"], None)
+                is_true("identity" in col)
+                self.check(
+                    col["identity"],
+                    dict(
+                        always=True,
+                        start=20,
+                        increment=1,
+                        minvalue=1,
+                        maxvalue=2147483647,
+                        cycle=False,
+                        cache=1,
+                    ),
+                    approx=True,
+                )
+
+
+class CompositeKeyReflectionTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ tb1 = Table(
+ "tb1",
+ metadata,
+ Column("id", Integer),
+ Column("attr", Integer),
+ Column("name", sql_types.VARCHAR(20)),
+ sa.PrimaryKeyConstraint("name", "id", "attr", name="pk_tb1"),
+ schema=None,
+ test_needs_fk=True,
+ )
+ Table(
+ "tb2",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("pid", Integer),
+ Column("pattr", Integer),
+ Column("pname", sql_types.VARCHAR(20)),
+ sa.ForeignKeyConstraint(
+ ["pname", "pid", "pattr"],
+ [tb1.c.name, tb1.c.id, tb1.c.attr],
+ name="fk_tb1_name_id_attr",
+ ),
+ schema=None,
+ test_needs_fk=True,
+ )
+
+ @testing.requires.primary_key_constraint_reflection
+ def test_pk_column_order(self):
+ # test for issue #5661
+ insp = inspect(self.bind)
+ primary_key = insp.get_pk_constraint(self.tables.tb1.name)
+ eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"])
+
+ @testing.requires.foreign_key_constraint_reflection
+ def test_fk_column_order(self):
+ # test for issue #5661
+ insp = inspect(self.bind)
+ foreign_keys = insp.get_foreign_keys(self.tables.tb2.name)
+ eq_(len(foreign_keys), 1)
+ fkey1 = foreign_keys[0]
+ eq_(fkey1.get("referred_columns"), ["name", "id", "attr"])
+ eq_(fkey1.get("constrained_columns"), ["pname", "pid", "pattr"])
+
+
+# names exported on a star-import of this module (how the dialect test
+# suites consume it)
+__all__ = (
+    "ComponentReflectionTest",
+    "ComponentReflectionTestExtra",
+    "TableNoColumnsTest",
+    "QuotedNameArgumentTest",
+    "HasTableTest",
+    "HasIndexTest",
+    "NormalizedNameTest",
+    "ComputedReflectionTest",
+    "IdentityReflectionTest",
+    "CompositeKeyReflectionTest",
+)
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
new file mode 100644
index 0000000..c41a550
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -0,0 +1,426 @@
+import datetime
+
+from .. import engines
+from .. import fixtures
+from ..assertions import eq_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import DateTime
+from ... import func
+from ... import Integer
+from ... import select
+from ... import sql
+from ... import String
+from ... import testing
+from ... import text
+from ... import util
+
+
+class RowFetchTest(fixtures.TablesTest):
+    """Row access patterns: attribute, string key, integer index,
+    Column-object key, duplicate labels and scalar subquery values.
+    """
+
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "plain_pk",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+        )
+        Table(
+            "has_dates",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("today", DateTime),
+        )
+
+    @classmethod
+    def insert_data(cls, connection):
+        connection.execute(
+            cls.tables.plain_pk.insert(),
+            [
+                {"id": 1, "data": "d1"},
+                {"id": 2, "data": "d2"},
+                {"id": 3, "data": "d3"},
+            ],
+        )
+
+        connection.execute(
+            cls.tables.has_dates.insert(),
+            [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
+        )
+
+    def test_via_attr(self, connection):
+        # columns are accessible as attributes on the Row
+        row = connection.execute(
+            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+        ).first()
+
+        eq_(row.id, 1)
+        eq_(row.data, "d1")
+
+    def test_via_string(self, connection):
+        # string-keyed access goes through Row._mapping
+        row = connection.execute(
+            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+        ).first()
+
+        eq_(row._mapping["id"], 1)
+        eq_(row._mapping["data"], "d1")
+
+    def test_via_int(self, connection):
+        # rows remain addressable by positional index
+        row = connection.execute(
+            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+        ).first()
+
+        eq_(row[0], 1)
+        eq_(row[1], "d1")
+
+    def test_via_col_object(self, connection):
+        # Column objects themselves work as _mapping keys
+        row = connection.execute(
+            self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+        ).first()
+
+        eq_(row._mapping[self.tables.plain_pk.c.id], 1)
+        eq_(row._mapping[self.tables.plain_pk.c.data], "d1")
+
+    @requirements.duplicate_names_in_cursor_description
+    def test_row_with_dupe_names(self, connection):
+        # two result columns share the name "data"; both keys and both
+        # values must come back
+        result = connection.execute(
+            select(
+                self.tables.plain_pk.c.data,
+                self.tables.plain_pk.c.data.label("data"),
+            ).order_by(self.tables.plain_pk.c.id)
+        )
+        row = result.first()
+        eq_(result.keys(), ["data", "data"])
+        eq_(row, ("d1", "d1"))
+
+    def test_row_w_scalar_select(self, connection):
+        """test that a scalar select as a column is returned as such
+        and that type conversion works OK.
+
+        (this is half a SQLAlchemy Core test and half to catch database
+        backends that may have unusual behavior with scalar selects.)
+
+        """
+        datetable = self.tables.has_dates
+        s = select(datetable.alias("x").c.today).scalar_subquery()
+        s2 = select(datetable.c.id, s.label("somelabel"))
+        row = connection.execute(s2).first()
+
+        eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))
+
+
+class PercentSchemaNamesTest(fixtures.TablesTest):
+    """tests using percent signs, spaces in table and column names.
+
+    This didn't work for PostgreSQL / MySQL drivers for a long time
+    but is now supported.
+
+    """
+
+    __requires__ = ("percent_schema_names",)
+
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        cls.tables.percent_table = Table(
+            "percent%table",
+            metadata,
+            Column("percent%", Integer),
+            Column("spaces % more spaces", Integer),
+        )
+        # lightweight (non-MetaData) construct pointing at the same
+        # table, to exercise the sql.table()/sql.column() path too
+        cls.tables.lightweight_percent_table = sql.table(
+            "percent%table",
+            sql.column("percent%"),
+            sql.column("spaces % more spaces"),
+        )
+
+    def test_single_roundtrip(self, connection):
+        # one execute() per row: tests single-statement parameter
+        # escaping of the % signs
+        percent_table = self.tables.percent_table
+        for params in [
+            {"percent%": 5, "spaces % more spaces": 12},
+            {"percent%": 7, "spaces % more spaces": 11},
+            {"percent%": 9, "spaces % more spaces": 10},
+            {"percent%": 11, "spaces % more spaces": 9},
+        ]:
+            connection.execute(percent_table.insert(), params)
+        self._assert_table(connection)
+
+    def test_executemany_roundtrip(self, connection):
+        # list-of-dicts execute(): tests the executemany path as well
+        percent_table = self.tables.percent_table
+        connection.execute(
+            percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
+        )
+        connection.execute(
+            percent_table.insert(),
+            [
+                {"percent%": 7, "spaces % more spaces": 11},
+                {"percent%": 9, "spaces % more spaces": 10},
+                {"percent%": 11, "spaces % more spaces": 9},
+            ],
+        )
+        self._assert_table(connection)
+
+    def _assert_table(self, conn):
+        # Verify SELECT / WHERE / row-mapping access / UPDATE all work
+        # against the percent-named table, via the full Table, its alias,
+        # and the lightweight construct.
+        percent_table = self.tables.percent_table
+        lightweight_percent_table = self.tables.lightweight_percent_table
+
+        for table in (
+            percent_table,
+            percent_table.alias(),
+            lightweight_percent_table,
+            lightweight_percent_table.alias(),
+        ):
+            eq_(
+                list(
+                    conn.execute(table.select().order_by(table.c["percent%"]))
+                ),
+                [(5, 12), (7, 11), (9, 10), (11, 9)],
+            )
+
+            eq_(
+                list(
+                    conn.execute(
+                        table.select()
+                        .where(table.c["spaces % more spaces"].in_([9, 10]))
+                        .order_by(table.c["percent%"])
+                    )
+                ),
+                [(9, 10), (11, 9)],
+            )
+
+            row = conn.execute(
+                table.select().order_by(table.c["percent%"])
+            ).first()
+            eq_(row._mapping["percent%"], 5)
+            eq_(row._mapping["spaces % more spaces"], 12)
+
+            eq_(row._mapping[table.c["percent%"]], 5)
+            eq_(row._mapping[table.c["spaces % more spaces"]], 12)
+
+        conn.execute(
+            percent_table.update().values(
+                {percent_table.c["spaces % more spaces"]: 15}
+            )
+        )
+
+        eq_(
+            list(
+                conn.execute(
+                    percent_table.select().order_by(
+                        percent_table.c["percent%"]
+                    )
+                )
+            ),
+            [(5, 15), (7, 15), (9, 15), (11, 15)],
+        )
+
+
+class ServerSideCursorsTest(
+    fixtures.TestBase, testing.AssertsExecutionResults
+):
+    """Verify when the engine/statement options actually produce a
+    server-side (streaming) DBAPI cursor, and that fetches over one
+    round trip correctly.
+    """
+
+    __requires__ = ("server_side_cursors",)
+
+    __backend__ = True
+
+    def _is_server_side(self, cursor):
+        # TODO: this is a huge issue as it prevents these tests from being
+        # usable by third party dialects.
+        # Each driver exposes "server-sideness" differently; probe per
+        # driver name.
+        if self.engine.dialect.driver == "psycopg2":
+            return bool(cursor.name)
+        elif self.engine.dialect.driver == "pymysql":
+            sscursor = __import__("pymysql.cursors").cursors.SSCursor
+            return isinstance(cursor, sscursor)
+        elif self.engine.dialect.driver in ("aiomysql", "asyncmy"):
+            return cursor.server_side
+        elif self.engine.dialect.driver == "mysqldb":
+            sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
+            return isinstance(cursor, sscursor)
+        elif self.engine.dialect.driver == "mariadbconnector":
+            return not cursor.buffered
+        elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
+            return cursor.server_side
+        elif self.engine.dialect.driver == "pg8000":
+            return getattr(cursor, "server_side", False)
+        else:
+            return False
+
+    def _fixture(self, server_side_cursors):
+        # Build a testing engine with the (deprecated) engine-wide
+        # server_side_cursors flag.  The deprecation warning is only
+        # expected for the True case — presumably passing False does not
+        # warn; confirm against create_engine's deprecation handling.
+        if server_side_cursors:
+            with testing.expect_deprecated(
+                "The create_engine.server_side_cursors parameter is "
+                "deprecated and will be removed in a future release. "
+                "Please use the Connection.execution_options.stream_results "
+                "parameter."
+            ):
+                self.engine = engines.testing_engine(
+                    options={"server_side_cursors": server_side_cursors}
+                )
+        else:
+            self.engine = engines.testing_engine(
+                options={"server_side_cursors": server_side_cursors}
+            )
+        return self.engine
+
+    # combinations are (id, engine_ss_arg, statement, cursor_ss_status):
+    # the engine-wide flag, the statement to run, and whether a
+    # server-side cursor is expected as a result
+    @testing.combinations(
+        ("global_string", True, "select 1", True),
+        ("global_text", True, text("select 1"), True),
+        ("global_expr", True, select(1), True),
+        ("global_off_explicit", False, text("select 1"), False),
+        (
+            "stmt_option",
+            False,
+            select(1).execution_options(stream_results=True),
+            True,
+        ),
+        (
+            "stmt_option_disabled",
+            True,
+            select(1).execution_options(stream_results=False),
+            False,
+        ),
+        ("for_update_expr", True, select(1).with_for_update(), True),
+        # TODO: need a real requirement for this, or dont use this test
+        (
+            "for_update_string",
+            True,
+            "SELECT 1 FOR UPDATE",
+            True,
+            testing.skip_if("sqlite"),
+        ),
+        ("text_no_ss", False, text("select 42"), False),
+        (
+            "text_ss_option",
+            False,
+            text("select 42").execution_options(stream_results=True),
+            True,
+        ),
+        id_="iaaa",
+        argnames="engine_ss_arg, statement, cursor_ss_status",
+    )
+    def test_ss_cursor_status(
+        self, engine_ss_arg, statement, cursor_ss_status
+    ):
+        engine = self._fixture(engine_ss_arg)
+        with engine.begin() as conn:
+            # plain strings go through exec_driver_sql()
+            if isinstance(statement, util.string_types):
+                result = conn.exec_driver_sql(statement)
+            else:
+                result = conn.execute(statement)
+            eq_(self._is_server_side(result.cursor), cursor_ss_status)
+            result.close()
+
+    def test_conn_option(self):
+        # stream_results set at the connection level enables SS cursors
+        engine = self._fixture(False)
+
+        with engine.connect() as conn:
+            # should be enabled for this one
+            result = conn.execution_options(
+                stream_results=True
+            ).exec_driver_sql("select 1")
+            assert self._is_server_side(result.cursor)
+
+    def test_stmt_enabled_conn_option_disabled(self):
+        # the connection-level option takes precedence over the
+        # statement-level one
+        engine = self._fixture(False)
+
+        s = select(1).execution_options(stream_results=True)
+
+        with engine.connect() as conn:
+            # not this one
+            result = conn.execution_options(stream_results=False).execute(s)
+            assert not self._is_server_side(result.cursor)
+
+    def test_aliases_and_ss(self):
+        engine = self._fixture(False)
+        s1 = (
+            select(sql.literal_column("1").label("x"))
+            .execution_options(stream_results=True)
+            .subquery()
+        )
+
+        # options don't propagate out when subquery is used as a FROM clause
+        with engine.begin() as conn:
+            result = conn.execute(s1.select())
+            assert not self._is_server_side(result.cursor)
+            result.close()
+
+        s2 = select(1).select_from(s1)
+        with engine.begin() as conn:
+            result = conn.execute(s2)
+            assert not self._is_server_side(result.cursor)
+            result.close()
+
+    def test_roundtrip_fetchall(self, metadata):
+        # full INSERT/SELECT/UPDATE/DELETE round trip with fetchall()
+        # over a server-side cursor; self.metadata is supplied by the
+        # "metadata" fixture
+        md = self.metadata
+
+        engine = self._fixture(True)
+        test_table = Table(
+            "test_table",
+            md,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+        )
+
+        with engine.begin() as connection:
+            test_table.create(connection, checkfirst=True)
+            connection.execute(test_table.insert(), dict(data="data1"))
+            connection.execute(test_table.insert(), dict(data="data2"))
+            eq_(
+                connection.execute(
+                    test_table.select().order_by(test_table.c.id)
+                ).fetchall(),
+                [(1, "data1"), (2, "data2")],
+            )
+            connection.execute(
+                test_table.update()
+                .where(test_table.c.id == 2)
+                .values(data=test_table.c.data + " updated")
+            )
+            eq_(
+                connection.execute(
+                    test_table.select().order_by(test_table.c.id)
+                ).fetchall(),
+                [(1, "data1"), (2, "data2 updated")],
+            )
+            connection.execute(test_table.delete())
+            eq_(
+                connection.scalar(
+                    select(func.count("*")).select_from(test_table)
+                ),
+                0,
+            )
+
+    def test_roundtrip_fetchmany(self, metadata):
+        # partial fetches (fetchmany) across a server-side cursor must
+        # return contiguous, non-overlapping slices
+        md = self.metadata
+
+        engine = self._fixture(True)
+        test_table = Table(
+            "test_table",
+            md,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+        )
+
+        with engine.begin() as connection:
+            test_table.create(connection, checkfirst=True)
+            connection.execute(
+                test_table.insert(),
+                [dict(data="data%d" % i) for i in range(1, 20)],
+            )
+
+            result = connection.execute(
+                test_table.select().order_by(test_table.c.id)
+            )
+
+            eq_(
+                result.fetchmany(5),
+                [(i, "data%d" % i) for i in range(1, 6)],
+            )
+            eq_(
+                result.fetchmany(10),
+                [(i, "data%d" % i) for i in range(6, 16)],
+            )
+            eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])
diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py
new file mode 100644
index 0000000..82e831f
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_rowcount.py
@@ -0,0 +1,165 @@
+from sqlalchemy import bindparam
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import text
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+
+
+class RowCountTest(fixtures.TablesTest):
+    """test rowcount functionality"""
+
+    __requires__ = ("sane_rowcount",)
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "employees",
+            metadata,
+            Column(
+                "employee_id",
+                Integer,
+                autoincrement=False,
+                primary_key=True,
+            ),
+            Column("name", String(50)),
+            Column("department", String(1)),
+        )
+
+    @classmethod
+    def insert_data(cls, connection):
+        # three employees per department A/B/C; employee_id is the
+        # enumerate() index starting at 0
+        cls.data = data = [
+            ("Angela", "A"),
+            ("Andrew", "A"),
+            ("Anand", "A"),
+            ("Bob", "B"),
+            ("Bobette", "B"),
+            ("Buffy", "B"),
+            ("Charlie", "C"),
+            ("Cynthia", "C"),
+            ("Chris", "C"),
+        ]
+
+        employees_table = cls.tables.employees
+        connection.execute(
+            employees_table.insert(),
+            [
+                {"employee_id": i, "name": n, "department": d}
+                for i, (n, d) in enumerate(data)
+            ],
+        )
+
+    def test_basic(self, connection):
+        # sanity check that the fixture rows are present and ordered
+        employees_table = self.tables.employees
+        s = select(
+            employees_table.c.name, employees_table.c.department
+        ).order_by(employees_table.c.employee_id)
+        rows = connection.execute(s).fetchall()
+
+        eq_(rows, self.data)
+
+    def test_update_rowcount1(self, connection):
+        employees_table = self.tables.employees
+
+        # WHERE matches 3, 3 rows changed
+        department = employees_table.c.department
+        r = connection.execute(
+            employees_table.update().where(department == "C"),
+            {"department": "Z"},
+        )
+        assert r.rowcount == 3
+
+    def test_update_rowcount2(self, connection):
+        employees_table = self.tables.employees
+
+        # WHERE matches 3, 0 rows changed
+        department = employees_table.c.department
+
+        r = connection.execute(
+            employees_table.update().where(department == "C"),
+            {"department": "C"},
+        )
+        eq_(r.rowcount, 3)
+
+    @testing.requires.sane_rowcount_w_returning
+    def test_update_rowcount_return_defaults(self, connection):
+        # rowcount must remain sane when RETURNING is in play
+        employees_table = self.tables.employees
+
+        department = employees_table.c.department
+        stmt = (
+            employees_table.update()
+            .where(department == "C")
+            .values(name=employees_table.c.department + "Z")
+            .return_defaults()
+        )
+
+        r = connection.execute(stmt)
+        eq_(r.rowcount, 3)
+
+    def test_raw_sql_rowcount(self, connection):
+        # test issue #3622, make sure eager rowcount is called for text
+        result = connection.exec_driver_sql(
+            "update employees set department='Z' where department='C'"
+        )
+        eq_(result.rowcount, 3)
+
+    def test_text_rowcount(self, connection):
+        # test issue #3622, make sure eager rowcount is called for text
+        result = connection.execute(
+            text("update employees set department='Z' " "where department='C'")
+        )
+        eq_(result.rowcount, 3)
+
+    def test_delete_rowcount(self, connection):
+        employees_table = self.tables.employees
+
+        # WHERE matches 3, 3 rows deleted
+        department = employees_table.c.department
+        r = connection.execute(
+            employees_table.delete().where(department == "C")
+        )
+        eq_(r.rowcount, 3)
+
+    @testing.requires.sane_multi_rowcount
+    def test_multi_update_rowcount(self, connection):
+        # executemany: total rowcount across the parameter sets; the
+        # nonexistent name matches nothing, so 2 of 3 sets hit a row
+        employees_table = self.tables.employees
+        stmt = (
+            employees_table.update()
+            .where(employees_table.c.name == bindparam("emp_name"))
+            .values(department="C")
+        )
+
+        r = connection.execute(
+            stmt,
+            [
+                {"emp_name": "Bob"},
+                {"emp_name": "Cynthia"},
+                {"emp_name": "nonexistent"},
+            ],
+        )
+
+        eq_(r.rowcount, 2)
+
+    @testing.requires.sane_multi_rowcount
+    def test_multi_delete_rowcount(self, connection):
+        # executemany DELETE: same aggregation rule as the update case
+        employees_table = self.tables.employees
+
+        stmt = employees_table.delete().where(
+            employees_table.c.name == bindparam("emp_name")
+        )
+
+        r = connection.execute(
+            stmt,
+            [
+                {"emp_name": "Bob"},
+                {"emp_name": "Cynthia"},
+                {"emp_name": "nonexistent"},
+            ],
+        )
+
+        eq_(r.rowcount, 2)
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
new file mode 100644
index 0000000..cb78fff
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -0,0 +1,1783 @@
+import itertools
+
+from .. import AssertsCompiledSQL
+from .. import AssertsExecutionResults
+from .. import config
+from .. import fixtures
+from ..assertions import assert_raises
+from ..assertions import eq_
+from ..assertions import in_
+from ..assertsql import CursorSQL
+from ..schema import Column
+from ..schema import Table
+from ... import bindparam
+from ... import case
+from ... import column
+from ... import Computed
+from ... import exists
+from ... import false
+from ... import ForeignKey
+from ... import func
+from ... import Identity
+from ... import Integer
+from ... import literal
+from ... import literal_column
+from ... import null
+from ... import select
+from ... import String
+from ... import table
+from ... import testing
+from ... import text
+from ... import true
+from ... import tuple_
+from ... import TupleType
+from ... import union
+from ... import util
+from ... import values
+from ...exc import DatabaseError
+from ...exc import ProgrammingError
+from ...util import collections_abc
+
+
+class CollateTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(100)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "collate data1"},
+ {"id": 2, "data": "collate data2"},
+ ],
+ )
+
+ def _assert_result(self, select, result):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select).fetchall(), result)
+
+ @testing.requires.order_by_collation
+ def test_collate_order_by(self):
+ collation = testing.requires.get_order_by_collation(testing.config)
+
+ self._assert_result(
+ select(self.tables.some_table).order_by(
+ self.tables.some_table.c.data.collate(collation).asc()
+ ),
+ [(1, "collate data1"), (2, "collate data2")],
+ )
+
+
+class OrderByLabelTest(fixtures.TablesTest):
+ """Test the dialect sends appropriate ORDER BY expressions when
+ labels are used.
+
+ This essentially exercises the "supports_simple_order_by_label"
+ setting.
+
+ """
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("q", String(50)),
+ Column("p", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"},
+ {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"},
+ {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"},
+ ],
+ )
+
+ def _assert_result(self, select, result):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select).fetchall(), result)
+
+ def test_plain(self):
+ table = self.tables.some_table
+ lx = table.c.x.label("lx")
+ self._assert_result(select(lx).order_by(lx), [(1,), (2,), (3,)])
+
+ def test_composed_int(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ self._assert_result(select(lx).order_by(lx), [(3,), (5,), (7,)])
+
+ def test_composed_multiple(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ ly = (func.lower(table.c.q) + table.c.p).label("ly")
+ self._assert_result(
+ select(lx, ly).order_by(lx, ly.desc()),
+ [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))],
+ )
+
+ def test_plain_desc(self):
+ table = self.tables.some_table
+ lx = table.c.x.label("lx")
+ self._assert_result(select(lx).order_by(lx.desc()), [(3,), (2,), (1,)])
+
+ def test_composed_int_desc(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ self._assert_result(select(lx).order_by(lx.desc()), [(7,), (5,), (3,)])
+
+ @testing.requires.group_by_complex_expression
+ def test_group_by_composed(self):
+ table = self.tables.some_table
+ expr = (table.c.x + table.c.y).label("lx")
+ stmt = (
+ select(func.count(table.c.id), expr).group_by(expr).order_by(expr)
+ )
+ self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)])
+
+
+class ValuesExpressionTest(fixtures.TestBase):
+ __requires__ = ("table_value_constructor",)
+
+ __backend__ = True
+
+ def test_tuples(self, connection):
+ value_expr = values(
+ column("id", Integer), column("name", String), name="my_values"
+ ).data([(1, "name1"), (2, "name2"), (3, "name3")])
+
+ eq_(
+ connection.execute(select(value_expr)).all(),
+ [(1, "name1"), (2, "name2"), (3, "name3")],
+ )
+
+
+class FetchLimitOffsetTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ {"id": 5, "x": 4, "y": 6},
+ ],
+ )
+
+ def _assert_result(
+ self, connection, select, result, params=(), set_=False
+ ):
+ if set_:
+ query_res = connection.execute(select, params).fetchall()
+ eq_(len(query_res), len(result))
+ eq_(set(query_res), set(result))
+
+ else:
+ eq_(connection.execute(select, params).fetchall(), result)
+
+ def _assert_result_str(self, select, result, params=()):
+ conn = config.db.connect(close_with_result=True)
+ eq_(conn.exec_driver_sql(select, params).fetchall(), result)
+
+ def test_simple_limit(self, connection):
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id)
+ self._assert_result(
+ connection,
+ stmt.limit(2),
+ [(1, 1, 2), (2, 2, 3)],
+ )
+ self._assert_result(
+ connection,
+ stmt.limit(3),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ def test_limit_render_multiple_times(self, connection):
+ table = self.tables.some_table
+ stmt = select(table.c.id).limit(1).scalar_subquery()
+
+ u = union(select(stmt), select(stmt)).subquery().select()
+
+ self._assert_result(
+ connection,
+ u,
+ [
+ (1,),
+ ],
+ )
+
+ @testing.requires.fetch_first
+ def test_simple_fetch(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(2),
+ [(1, 1, 2), (2, 2, 3)],
+ )
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(3),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.offset
+ def test_simple_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(2),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(3),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.combinations(
+ ([(2, 0), (2, 1), (3, 2)]),
+ ([(2, 1), (2, 0), (3, 2)]),
+ ([(3, 1), (2, 1), (3, 1)]),
+ argnames="cases",
+ )
+ @testing.requires.offset
+ def test_simple_limit_offset(self, connection, cases):
+ table = self.tables.some_table
+ connection = connection.execution_options(compiled_cache={})
+
+ assert_data = [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)]
+
+ for limit, offset in cases:
+ expected = assert_data[offset : offset + limit]
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(limit).offset(offset),
+ expected,
+ )
+
+ @testing.requires.fetch_first
+ def test_simple_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(2).offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(3).offset(2),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.fetch_no_order_by
+ def test_fetch_offset_no_order(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).fetch(10),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ @testing.requires.offset
+ def test_simple_offset_zero(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(0),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(1),
+ [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.offset
+ def test_limit_offset_nobinds(self):
+ """test that 'literal binds' mode works - no bound params."""
+
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id).limit(2).offset(1)
+ sql = stmt.compile(
+ dialect=config.db.dialect, compile_kwargs={"literal_binds": True}
+ )
+ sql = str(sql)
+
+ self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)])
+
+ @testing.requires.fetch_first
+ def test_fetch_offset_nobinds(self):
+ """test that 'literal binds' mode works - no bound params."""
+
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id).fetch(2).offset(1)
+ sql = stmt.compile(
+ dialect=config.db.dialect, compile_kwargs={"literal_binds": True}
+ )
+ sql = str(sql)
+
+ self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)])
+
+ @testing.requires.bound_limit_offset
+ def test_bound_limit(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(bindparam("l")),
+ [(1, 1, 2), (2, 2, 3)],
+ params={"l": 2},
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(bindparam("l")),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ params={"l": 3},
+ )
+
+ @testing.requires.bound_limit_offset
+ def test_bound_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"o": 2},
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"o": 1},
+ )
+
+ @testing.requires.bound_limit_offset
+ def test_bound_limit_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(bindparam("l"))
+ .offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4)],
+ params={"l": 2, "o": 1},
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(bindparam("l"))
+ .offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"l": 3, "o": 2},
+ )
+
+ @testing.requires.fetch_first
+ def test_bound_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(bindparam("f"))
+ .offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4)],
+ params={"f": 2, "o": 1},
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(bindparam("f"))
+ .offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"f": 3, "o": 2},
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .offset(literal_column("1") + literal_column("2")),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("2")),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ @testing.requires.fetch_first
+ @testing.requires.fetch_expression
+ def test_expr_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(literal_column("1") + literal_column("1"))
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_simple_limit_expr_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(2)
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(3)
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit_simple_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(2),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.fetch_ties
+ def test_simple_fetch_ties(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.x.desc()).fetch(1, with_ties=True),
+ [(4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.x.desc()).fetch(3, with_ties=True),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_ties(self, connection):
+ table = self.tables.some_table
+ fa = connection.execute(
+ select(table)
+ .order_by(table.c.x)
+ .fetch(2, with_ties=True)
+ .offset(2)
+ ).fetchall()
+ eq_(fa[0], (3, 3, 4))
+ eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)]))
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_ties_exact_number(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x)
+ .fetch(2, with_ties=True)
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x)
+ .fetch(3, with_ties=True)
+ .offset(3),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.fetch_percent
+ def test_simple_fetch_percent(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(20, percent=True),
+ [(1, 1, 2)],
+ )
+
+ @testing.requires.fetch_percent
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_percent(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(40, percent=True)
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_percent
+ def test_simple_fetch_percent_ties(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x.desc())
+ .fetch(20, percent=True, with_ties=True),
+ [(4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_percent
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_percent_ties(self, connection):
+ table = self.tables.some_table
+ fa = connection.execute(
+ select(table)
+ .order_by(table.c.x)
+ .fetch(40, percent=True, with_ties=True)
+ .offset(2)
+ ).fetchall()
+ eq_(fa[0], (3, 3, 4))
+ eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)]))
+
+
+class JoinTest(fixtures.TablesTest):
+ __backend__ = True
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("a", metadata, Column("id", Integer, primary_key=True))
+ Table(
+ "b",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("a_id", ForeignKey("a.id"), nullable=False),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.a.insert(),
+ [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}],
+ )
+
+ connection.execute(
+ cls.tables.b.insert(),
+ [
+ {"id": 1, "a_id": 1},
+ {"id": 2, "a_id": 1},
+ {"id": 4, "a_id": 2},
+ {"id": 5, "a_id": 3},
+ ],
+ )
+
+ def test_inner_join_fk(self):
+ a, b = self.tables("a", "b")
+
+ stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id)
+
+ self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)])
+
+ def test_inner_join_true(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.join(b, true()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(
+ stmt,
+ [
+ (a, b, c)
+ for (a,), (b, c) in itertools.product(
+ [(1,), (2,), (3,), (4,), (5,)],
+ [(1, 1), (2, 1), (4, 2), (5, 3)],
+ )
+ ],
+ )
+
+ def test_inner_join_false(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.join(b, false()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(stmt, [])
+
+ def test_outer_join_false(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.outerjoin(b, false()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(
+ stmt,
+ [
+ (1, None, None),
+ (2, None, None),
+ (3, None, None),
+ (4, None, None),
+ (5, None, None),
+ ],
+ )
+
+ def test_outer_join_fk(self):
+ a, b = self.tables("a", "b")
+
+ stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id)
+
+ self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)])
+
+
+class CompoundSelectTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ ],
+ )
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ def test_plain_union(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_select_from_plain_union(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2).alias().select()
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.order_by_col_from_union
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_wo_limit_offset
+ def test_order_by_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_distinct_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).distinct()
+ s2 = select(table).where(table.c.id == 3).distinct()
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_in_unions_from_alias(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ # this necessarily has double parens
+ u1 = union(s1, s2).alias()
+ self._assert_result(
+ u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_limit_offset_aliased_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = (
+ select(table)
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+ s2 = (
+ select(table)
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+
+class PostCompileParamsTest(
+ AssertsExecutionResults, AssertsCompiledSQL, fixtures.TablesTest
+):
+ __backend__ = True
+
+ __requires__ = ("standard_cursor_sql",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "z": "z1"},
+ {"id": 2, "x": 2, "y": 3, "z": "z2"},
+ {"id": 3, "x": 3, "y": 4, "z": "z3"},
+ {"id": 4, "x": 4, "y": 5, "z": "z4"},
+ ],
+ )
+
+ def test_compile(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", literal_execute=True)
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT some_table.id FROM some_table "
+ "WHERE some_table.x = __[POSTCOMPILE_q]",
+ {},
+ )
+
+ def test_compile_literal_binds(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", 10, literal_execute=True)
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT some_table.id FROM some_table WHERE some_table.x = 10",
+ {},
+ literal_binds=True,
+ )
+
+ def test_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", literal_execute=True)
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=10))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE some_table.x = 10",
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ def test_execute_expanding_plus_literal_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x.in_(bindparam("q", expanding=True, literal_execute=True))
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[5, 6, 7]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE some_table.x IN (5, 6, 7)",
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ @testing.requires.tuple_in
+ def test_execute_tuple_expanding_plus_literal_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ tuple_(table.c.x, table.c.y).in_(
+ bindparam("q", expanding=True, literal_execute=True)
+ )
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[(5, 10), (12, 18)]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE (some_table.x, some_table.y) "
+ "IN (%s(5, 10), (12, 18))"
+ % ("VALUES " if config.db.dialect.tuple_in_values else ""),
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ @testing.requires.tuple_in
+ def test_execute_tuple_expanding_plus_literal_heterogeneous_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam("q", expanding=True, literal_execute=True)
+ )
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[(5, "z1"), (12, "z3")]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE (some_table.x, some_table.z) "
+ "IN (%s(5, 'z1'), (12, 'z3'))"
+ % ("VALUES " if config.db.dialect.tuple_in_values else ""),
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+
+class ExpandingBoundInTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "z": "z1"},
+ {"id": 2, "x": 2, "y": 3, "z": "z2"},
+ {"id": 3, "x": 3, "y": 4, "z": "z3"},
+ {"id": 4, "x": 4, "y": 5, "z": "z4"},
+ ],
+ )
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ def test_multiple_empty_sets_bindparam(self):
+ # test that any anonymous aliasing used by the dialect
+ # is fine with duplicates
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_(bindparam("q")))
+ .where(table.c.y.in_(bindparam("p")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": [], "p": []})
+
+ def test_multiple_empty_sets_direct(self):
+ # test that any anonymous aliasing used by the dialect
+ # is fine with duplicates
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([]))
+ .where(table.c.y.in_([]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [])
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_heterogeneous_tuples_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_heterogeneous_tuples_direct(self):
+ table = self.tables.some_table
+
+ def go(val, expected):
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(val))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, expected)
+
+ go([], [])
+ go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)])
+ go([], [])
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_homogeneous_tuples_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_homogeneous_tuples_direct(self):
+ table = self.tables.some_table
+
+ def go(val, expected):
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(val))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, expected)
+
+ go([], [])
+ go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)])
+ go([], [])
+
+ def test_bound_in_scalar_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]})
+
+ def test_bound_in_scalar_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([2, 3, 4]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)])
+
+ def test_nonempty_in_plus_empty_notin(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([2, 3]))
+ .where(table.c.id.not_in([]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,)])
+
+ def test_empty_in_plus_notempty_notin(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([]))
+ .where(table.c.id.not_in([2, 3]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [])
+
+ def test_typed_str_in(self):
+ """test related to #7292.
+
+ as a type is given to the bound param, there is no ambiguity
+ to the type of element.
+
+ """
+
+ stmt = text(
+ "select id FROM some_table WHERE z IN :q ORDER BY id"
+ ).bindparams(bindparam("q", type_=String, expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": ["z2", "z3", "z4"]},
+ )
+
+ def test_untyped_str_in(self):
+ """test related to #7292.
+
+ for untyped expression, we look at the types of elements.
+ Test for Sequence to detect tuple in. but not strings or bytes!
+ as always....
+
+ """
+
+ stmt = text(
+ "select id FROM some_table WHERE z IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": ["z2", "z3", "z4"]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_two_tuple_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]}
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_two_tuple_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)])
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(
+ tuple_(table.c.x, table.c.z).in_(
+ [(2, "z2"), (3, "z3"), (4, "z4")]
+ )
+ )
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_text_bindparam(self):
+ # note this becomes ARRAY if we dont use expanding
+ # explicitly right now
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self):
+ class LikeATuple(collections_abc.Sequence):
+ def __init__(self, *data):
+ self._data = data
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def __getitem__(self, idx):
+ return self._data[idx]
+
+ def __len__(self):
+ return len(self._data)
+
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(
+ bindparam(
+ "q", type_=TupleType(Integer(), String()), expanding=True
+ )
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={
+ "q": [
+ LikeATuple(2, "z2"),
+ LikeATuple(3, "z3"),
+ LikeATuple(4, "z4"),
+ ]
+ },
+ )
+
    @testing.requires.tuple_in
    def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self):
        """Like the typed variant above, but with an untyped expanding
        parameter: Sequence objects that aren't tuples still work."""
        # note this becomes ARRAY if we dont use expanding
        # explicitly right now

        class LikeATuple(collections_abc.Sequence):
            def __init__(self, *data):
                self._data = data

            def __iter__(self):
                return iter(self._data)

            def __getitem__(self, idx):
                return self._data[idx]

            def __len__(self):
                return len(self._data)

        stmt = text(
            "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
        ).bindparams(bindparam("q", expanding=True))
        self._assert_result(
            stmt,
            [(2,), (3,), (4,)],
            params={
                "q": [
                    LikeATuple(2, "z2"),
                    LikeATuple(3, "z3"),
                    LikeATuple(4, "z4"),
                ]
            },
        )
+
    def test_empty_set_against_integer_bindparam(self):
        """IN with an empty bound list matches no rows."""
        table = self.tables.some_table
        stmt = (
            select(table.c.id)
            .where(table.c.x.in_(bindparam("q")))
            .order_by(table.c.id)
        )
        self._assert_result(stmt, [], params={"q": []})
+
+ def test_empty_set_against_integer_direct(self):
+ table = self.tables.some_table
+ stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id)
+ self._assert_result(stmt, [])
+
    def test_empty_set_against_integer_negation_bindparam(self):
        """NOT IN with an empty bound list matches every row."""
        table = self.tables.some_table
        stmt = (
            select(table.c.id)
            .where(table.c.x.not_in(bindparam("q")))
            .order_by(table.c.id)
        )
        self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
    def test_empty_set_against_integer_negation_direct(self):
        """NOT IN against a literal empty list matches every row."""
        table = self.tables.some_table
        stmt = (
            select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id)
        )
        self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
+
    def test_empty_set_against_string_bindparam(self):
        """Empty-IN behavior for a string column via bound parameter."""
        table = self.tables.some_table
        stmt = (
            select(table.c.id)
            .where(table.c.z.in_(bindparam("q")))
            .order_by(table.c.id)
        )
        self._assert_result(stmt, [], params={"q": []})
+
+ def test_empty_set_against_string_direct(self):
+ table = self.tables.some_table
+ stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id)
+ self._assert_result(stmt, [])
+
    def test_empty_set_against_string_negation_bindparam(self):
        """Empty NOT-IN behavior for a string column via bound parameter."""
        table = self.tables.some_table
        stmt = (
            select(table.c.id)
            .where(table.c.z.not_in(bindparam("q")))
            .order_by(table.c.id)
        )
        self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
+ def test_empty_set_against_string_negation_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
+
    def test_null_in_empty_set_is_false_bindparam(self, connection):
        """NULL IN <empty expanding param> must evaluate to false, not
        NULL/unknown."""
        stmt = select(
            case(
                (
                    null().in_(bindparam("foo", value=())),
                    true(),
                ),
                else_=false(),
            )
        )
        # backends may return False or 0 for a boolean false
        in_(connection.execute(stmt).fetchone()[0], (False, 0))
+
+ def test_null_in_empty_set_is_false_direct(self, connection):
+ stmt = select(
+ case(
+ (
+ null().in_([]),
+ true(),
+ ),
+ else_=false(),
+ )
+ )
+ in_(connection.execute(stmt).fetchone()[0], (False, 0))
+
+
class LikeFunctionsTest(fixtures.TablesTest):
    """Tests for startswith/endswith/contains and regexp operators,
    including autoescape and explicit escape-character handling."""

    __backend__ = True

    run_inserts = "once"
    run_deletes = None

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )

    @classmethod
    def insert_data(cls, connection):
        # rows deliberately contain LIKE metacharacters (%, _) and the
        # escape character used by the tests (#)
        connection.execute(
            cls.tables.some_table.insert(),
            [
                {"id": 1, "data": "abcdefg"},
                {"id": 2, "data": "ab/cdefg"},
                {"id": 3, "data": "ab%cdefg"},
                {"id": 4, "data": "ab_cdefg"},
                {"id": 5, "data": "abcde/fg"},
                {"id": 6, "data": "abcde%fg"},
                {"id": 7, "data": "ab#cdefg"},
                {"id": 8, "data": "ab9cdefg"},
                {"id": 9, "data": "abcde#fg"},
                {"id": 10, "data": "abcd9fg"},
                {"id": 11, "data": None},
            ],
        )

    def _test(self, expr, expected):
        """Run ``expr`` as a WHERE criterion and compare the matched id set."""
        some_table = self.tables.some_table

        with config.db.connect() as conn:
            rows = {
                value
                for value, in conn.execute(select(some_table.c.id).where(expr))
            }

        eq_(rows, expected)

    def test_startswith_unescaped(self):
        # "%" acts as a wildcard here, so nearly all rows match
        col = self.tables.some_table.c.data
        self._test(col.startswith("ab%c"), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})

    def test_startswith_autoescape(self):
        # autoescape makes "%" literal; only the row containing "ab%c" matches
        col = self.tables.some_table.c.data
        self._test(col.startswith("ab%c", autoescape=True), {3})

    def test_startswith_sqlexpr(self):
        col = self.tables.some_table.c.data
        self._test(
            col.startswith(literal_column("'ab%c'")),
            {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
        )

    def test_startswith_escape(self):
        # "##" is an escaped literal "#" under escape="#"
        col = self.tables.some_table.c.data
        self._test(col.startswith("ab##c", escape="#"), {7})

    def test_startswith_autoescape_escape(self):
        col = self.tables.some_table.c.data
        self._test(col.startswith("ab%c", autoescape=True, escape="#"), {3})
        self._test(col.startswith("ab#c", autoescape=True, escape="#"), {7})

    def test_endswith_unescaped(self):
        col = self.tables.some_table.c.data
        self._test(col.endswith("e%fg"), {1, 2, 3, 4, 5, 6, 7, 8, 9})

    def test_endswith_sqlexpr(self):
        col = self.tables.some_table.c.data
        self._test(
            col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9}
        )

    def test_endswith_autoescape(self):
        col = self.tables.some_table.c.data
        self._test(col.endswith("e%fg", autoescape=True), {6})

    def test_endswith_escape(self):
        col = self.tables.some_table.c.data
        self._test(col.endswith("e##fg", escape="#"), {9})

    def test_endswith_autoescape_escape(self):
        col = self.tables.some_table.c.data
        self._test(col.endswith("e%fg", autoescape=True, escape="#"), {6})
        self._test(col.endswith("e#fg", autoescape=True, escape="#"), {9})

    def test_contains_unescaped(self):
        col = self.tables.some_table.c.data
        self._test(col.contains("b%cde"), {1, 2, 3, 4, 5, 6, 7, 8, 9})

    def test_contains_autoescape(self):
        col = self.tables.some_table.c.data
        self._test(col.contains("b%cde", autoescape=True), {3})

    def test_contains_escape(self):
        col = self.tables.some_table.c.data
        self._test(col.contains("b##cde", escape="#"), {7})

    def test_contains_autoescape_escape(self):
        col = self.tables.some_table.c.data
        self._test(col.contains("b%cd", autoescape=True, escape="#"), {3})
        self._test(col.contains("b#cd", autoescape=True, escape="#"), {7})

    @testing.requires.regexp_match
    def test_not_regexp_match(self):
        col = self.tables.some_table.c.data
        self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10})

    @testing.requires.regexp_replace
    def test_regexp_replace(self):
        col = self.tables.some_table.c.data
        self._test(
            col.regexp_replace("a.cde", "FOO").contains("FOO"), {1, 5, 6, 9}
        )

    @testing.requires.regexp_match
    @testing.combinations(
        ("a.cde", {1, 5, 6, 9}),
        ("abc", {1, 5, 6, 9, 10}),
        ("^abc", {1, 5, 6, 9, 10}),
        ("9cde", {8}),
        ("^a", set(range(1, 11))),
        ("(b|c)", set(range(1, 11))),
        ("^(b|c)", set()),
    )
    def test_regexp_match(self, text, expected):
        col = self.tables.some_table.c.data
        self._test(col.regexp_match(text), expected)
+
+
class ComputedColumnTest(fixtures.TablesTest):
    """Round trips for server-computed (GENERATED) columns."""

    __backend__ = True
    __requires__ = ("computed_columns",)

    @classmethod
    def define_tables(cls, metadata):
        # "area" and "perimeter" are computed server-side from "side"
        Table(
            "square",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("side", Integer),
            Column("area", Integer, Computed("side * side")),
            Column("perimeter", Integer, Computed("4 * side")),
        )

    @classmethod
    def insert_data(cls, connection):
        rows = [{"id": 1, "side": 10}, {"id": 10, "side": 42}]
        connection.execute(cls.tables.square.insert(), rows)

    def test_select_all(self):
        """SELECT * returns the computed values alongside stored ones."""
        square = self.tables.square
        with config.db.connect() as conn:
            fetched = conn.execute(
                select(text("*")).select_from(square).order_by(square.c.id)
            ).fetchall()
        eq_(fetched, [(1, 10, 100, 40), (10, 42, 1764, 168)])

    def test_select_columns(self):
        """Computed columns can be selected explicitly by name."""
        square = self.tables.square
        with config.db.connect() as conn:
            fetched = conn.execute(
                select(square.c.area, square.c.perimeter)
                .select_from(square)
                .order_by(square.c.id)
            ).fetchall()
        eq_(fetched, [(100, 40), (1764, 168)])
+
+
class IdentityColumnTest(fixtures.TablesTest):
    """Round trips for IDENTITY columns, including GENERATED ALWAYS,
    negative increments, and explicit-value inserts."""

    __backend__ = True
    __requires__ = ("identity_columns",)
    run_inserts = "once"
    run_deletes = "once"

    @classmethod
    def define_tables(cls, metadata):
        # tbl_a: GENERATED ALWAYS identity starting at 42
        Table(
            "tbl_a",
            metadata,
            Column(
                "id",
                Integer,
                Identity(
                    always=True, start=42, nominvalue=True, nomaxvalue=True
                ),
                primary_key=True,
            ),
            Column("desc", String(100)),
        )
        # tbl_b: BY DEFAULT identity counting down from 0 in steps of -5
        Table(
            "tbl_b",
            metadata,
            Column(
                "id",
                Integer,
                Identity(increment=-5, start=0, minvalue=-1000, maxvalue=0),
                primary_key=True,
            ),
            Column("desc", String(100)),
        )

    @classmethod
    def insert_data(cls, connection):
        connection.execute(
            cls.tables.tbl_a.insert(),
            [{"desc": "a"}, {"desc": "b"}],
        )
        connection.execute(
            cls.tables.tbl_b.insert(),
            [{"desc": "a"}, {"desc": "b"}],
        )
        # tbl_b is BY DEFAULT, so an explicit id value is accepted
        connection.execute(
            cls.tables.tbl_b.insert(),
            [{"id": 42, "desc": "c"}],
        )

    def test_select_all(self, connection):
        res = connection.execute(
            select(text("*"))
            .select_from(self.tables.tbl_a)
            .order_by(self.tables.tbl_a.c.id)
        ).fetchall()
        eq_(res, [(42, "a"), (43, "b")])

        res = connection.execute(
            select(text("*"))
            .select_from(self.tables.tbl_b)
            .order_by(self.tables.tbl_b.c.id)
        ).fetchall()
        eq_(res, [(-5, "b"), (0, "a"), (42, "c")])

    def test_select_columns(self, connection):

        res = connection.execute(
            select(self.tables.tbl_a.c.id).order_by(self.tables.tbl_a.c.id)
        ).fetchall()
        eq_(res, [(42,), (43,)])

    @testing.requires.identity_columns_standard
    def test_insert_always_error(self, connection):
        """Supplying an explicit value for a GENERATED ALWAYS identity
        column must raise a database error."""

        def fn():
            connection.execute(
                self.tables.tbl_a.insert(),
                [{"id": 200, "desc": "a"}],
            )

        assert_raises((DatabaseError, ProgrammingError), fn)
+
+
class IdentityAutoincrementTest(fixtures.TablesTest):
    """Test that an Identity() column behaves as an autoincrementing
    primary key."""

    __backend__ = True
    __requires__ = ("autoincrement_without_sequence",)

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "tbl",
            metadata,
            Column(
                "id",
                Integer,
                Identity(),
                primary_key=True,
                autoincrement=True,
            ),
            Column("desc", String(100)),
        )

    def test_autoincrement_with_identity(self, connection):
        # The INSERT result is never inspected; binding it to ``res`` only
        # to overwrite it on the next line was a dead assignment.
        connection.execute(self.tables.tbl.insert(), {"desc": "row"})
        res = connection.execute(self.tables.tbl.select()).first()
        eq_(res, (1, "row"))
+
+
class ExistsTest(fixtures.TablesTest):
    """EXISTS subquery behavior against a small fixture table."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "stuff",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )

    @classmethod
    def insert_data(cls, connection):
        rows = [
            {"id": 1, "data": "some data"},
            {"id": 2, "data": "some data"},
            {"id": 3, "data": "some data"},
            {"id": 4, "data": "some other data"},
        ]
        connection.execute(cls.tables.stuff.insert(), rows)

    def test_select_exists(self, connection):
        """A satisfied EXISTS yields the outer SELECT's row."""
        stuff = self.tables.stuff
        stmt = select(literal(1)).where(
            exists().where(stuff.c.data == "some data")
        )
        eq_(connection.execute(stmt).fetchall(), [(1,)])

    def test_select_exists_false(self, connection):
        """An unsatisfied EXISTS yields no rows."""
        stuff = self.tables.stuff
        stmt = select(literal(1)).where(
            exists().where(stuff.c.data == "no data")
        )
        eq_(connection.execute(stmt).fetchall(), [])
+
+
class DistinctOnTest(AssertsCompiledSQL, fixtures.TablesTest):
    """Check that DISTINCT ON falls back to plain DISTINCT (with a
    deprecation warning) on backends that don't support it."""

    __backend__ = True

    @testing.fails_if(testing.requires.supports_distinct_on)
    def test_distinct_on(self):
        stm = select("*").distinct(column("q")).select_from(table("foo"))
        with testing.expect_deprecated(
            "DISTINCT ON is currently supported only by the PostgreSQL "
        ):
            self.assert_compile(stm, "SELECT DISTINCT * FROM foo")
+
+
class IsOrIsNotDistinctFromTest(fixtures.TablesTest):
    """NULL-safe comparison via IS [NOT] DISTINCT FROM across value/NULL
    combinations."""

    __backend__ = True
    __requires__ = ("supports_is_distinct_from",)

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "is_distinct_test",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("col_a", Integer, nullable=True),
            Column("col_b", Integer, nullable=True),
        )

    @testing.combinations(
        ("both_int_different", 0, 1, 1),
        ("both_int_same", 1, 1, 0),
        ("one_null_first", None, 1, 1),
        ("one_null_second", 0, None, 1),
        ("both_null", None, None, 0),
        id_="iaaa",
        argnames="col_a_value, col_b_value, expected_row_count_for_is",
    )
    def test_is_or_is_not_distinct_from(
        self, col_a_value, col_b_value, expected_row_count_for_is, connection
    ):
        tbl = self.tables.is_distinct_test

        connection.execute(
            tbl.insert(),
            [{"id": 1, "col_a": col_a_value, "col_b": col_b_value}],
        )

        result = connection.execute(
            tbl.select().where(tbl.c.col_a.is_distinct_from(tbl.c.col_b))
        ).fetchall()
        eq_(
            len(result),
            expected_row_count_for_is,
        )

        # IS NOT DISTINCT FROM is the exact complement for the single row
        expected_row_count_for_is_not = (
            1 if expected_row_count_for_is == 0 else 0
        )
        result = connection.execute(
            tbl.select().where(tbl.c.col_a.is_not_distinct_from(tbl.c.col_b))
        ).fetchall()
        eq_(
            len(result),
            expected_row_count_for_is_not,
        )
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
new file mode 100644
index 0000000..d6747d2
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -0,0 +1,282 @@
+from .. import config
+from .. import fixtures
+from ..assertions import eq_
+from ..assertions import is_true
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import inspect
+from ... import Integer
+from ... import MetaData
+from ... import Sequence
+from ... import String
+from ... import testing
+
+
class SequenceTest(fixtures.TablesTest):
    """Round trips for explicitly Sequence-driven primary keys, including
    optional sequences, no-RETURNING inserts, and schema translation."""

    __requires__ = ("sequences",)
    __backend__ = True

    run_create_tables = "each"

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "seq_pk",
            metadata,
            Column(
                "id",
                Integer,
                Sequence("tab_id_seq"),
                primary_key=True,
            ),
            Column("data", String(50)),
        )

        Table(
            "seq_opt_pk",
            metadata,
            Column(
                "id",
                Integer,
                # NOTE(review): same sequence name as seq_pk; optional=True
                # means the sequence fires only on backends lacking another
                # autoincrement mechanism — confirm the shared name is
                # intentional.
                Sequence("tab_id_seq", data_type=Integer, optional=True),
                primary_key=True,
            ),
            Column("data", String(50)),
        )

        Table(
            "seq_no_returning",
            metadata,
            Column(
                "id",
                Integer,
                Sequence("noret_id_seq"),
                primary_key=True,
            ),
            Column("data", String(50)),
            implicit_returning=False,
        )

        if testing.requires.schemas.enabled:
            Table(
                "seq_no_returning_sch",
                metadata,
                Column(
                    "id",
                    Integer,
                    Sequence("noret_sch_id_seq", schema=config.test_schema),
                    primary_key=True,
                ),
                Column("data", String(50)),
                implicit_returning=False,
                schema=config.test_schema,
            )

    def test_insert_roundtrip(self, connection):
        connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
        self._assert_round_trip(self.tables.seq_pk, connection)

    def test_insert_lastrowid(self, connection):
        r = connection.execute(
            self.tables.seq_pk.insert(), dict(data="some data")
        )
        eq_(
            r.inserted_primary_key, (testing.db.dialect.default_sequence_base,)
        )

    def test_nextval_direct(self, connection):
        # executing the Sequence default directly emits a standalone nextval
        r = connection.execute(self.tables.seq_pk.c.id.default)
        eq_(r, testing.db.dialect.default_sequence_base)

    @requirements.sequences_optional
    def test_optional_seq(self, connection):
        r = connection.execute(
            self.tables.seq_opt_pk.insert(), dict(data="some data")
        )
        eq_(r.inserted_primary_key, (1,))

    def _assert_round_trip(self, table, conn):
        # helper: the single inserted row carries the dialect's starting
        # sequence value as its primary key
        row = conn.execute(table.select()).first()
        eq_(row, (testing.db.dialect.default_sequence_base, "some data"))

    def test_insert_roundtrip_no_implicit_returning(self, connection):
        connection.execute(
            self.tables.seq_no_returning.insert(), dict(data="some data")
        )
        self._assert_round_trip(self.tables.seq_no_returning, connection)

    @testing.combinations((True,), (False,), argnames="implicit_returning")
    @testing.requires.schemas
    def test_insert_roundtrip_translate(self, connection, implicit_returning):
        """Sequence execution honors schema_translate_map on insert."""

        seq_no_returning = Table(
            "seq_no_returning_sch",
            MetaData(),
            Column(
                "id",
                Integer,
                Sequence("noret_sch_id_seq", schema="alt_schema"),
                primary_key=True,
            ),
            Column("data", String(50)),
            implicit_returning=implicit_returning,
            schema="alt_schema",
        )

        connection = connection.execution_options(
            schema_translate_map={"alt_schema": config.test_schema}
        )
        connection.execute(seq_no_returning.insert(), dict(data="some data"))
        self._assert_round_trip(seq_no_returning, connection)

    @testing.requires.schemas
    def test_nextval_direct_schema_translate(self, connection):
        """Standalone nextval execution honors schema_translate_map."""
        seq = Sequence("noret_sch_id_seq", schema="alt_schema")
        connection = connection.execution_options(
            schema_translate_map={"alt_schema": config.test_schema}
        )

        r = connection.execute(seq)
        eq_(r, testing.db.dialect.default_sequence_base)
+
+
class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
    """Compilation of sequence "next value" expressions under
    literal_binds."""

    __requires__ = ("sequences",)
    __backend__ = True

    def test_literal_binds_inline_compile(self, connection):
        table = Table(
            "x",
            MetaData(),
            Column("y", Integer, Sequence("y_seq")),
            Column("q", Integer),
        )

        stmt = table.insert().values(q=5)

        # ask the dialect's compiler how it renders this sequence's nextval
        seq_nextval = connection.dialect.statement_compiler(
            statement=None, dialect=connection.dialect
        ).visit_sequence(Sequence("y_seq"))
        self.assert_compile(
            stmt,
            "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,),
            literal_binds=True,
            dialect=connection.dialect,
        )
+
+
class HasSequenceTest(fixtures.TablesTest):
    """Inspector.has_sequence / get_sequence_names behavior, including
    schema-qualified lookups."""

    run_deletes = None

    __requires__ = ("sequences",)
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Sequence("user_id_seq", metadata=metadata)
        Sequence(
            "other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True
        )
        if testing.requires.schemas.enabled:
            # same-named sequence in the test schema, plus one unique to it
            Sequence(
                "user_id_seq", schema=config.test_schema, metadata=metadata
            )
            Sequence(
                "schema_seq", schema=config.test_schema, metadata=metadata
            )
        # a table, to verify has_sequence doesn't match non-sequence objects
        Table(
            "user_id_table",
            metadata,
            Column("id", Integer, primary_key=True),
        )

    def test_has_sequence(self, connection):
        eq_(
            inspect(connection).has_sequence("user_id_seq"),
            True,
        )

    def test_has_sequence_other_object(self, connection):
        # a table of the same naming style is not reported as a sequence
        eq_(
            inspect(connection).has_sequence("user_id_table"),
            False,
        )

    @testing.requires.schemas
    def test_has_sequence_schema(self, connection):
        eq_(
            inspect(connection).has_sequence(
                "user_id_seq", schema=config.test_schema
            ),
            True,
        )

    def test_has_sequence_neg(self, connection):
        eq_(
            inspect(connection).has_sequence("some_sequence"),
            False,
        )

    @testing.requires.schemas
    def test_has_sequence_schemas_neg(self, connection):
        eq_(
            inspect(connection).has_sequence(
                "some_sequence", schema=config.test_schema
            ),
            False,
        )

    @testing.requires.schemas
    def test_has_sequence_default_not_in_remote(self, connection):
        # default-schema sequences are not visible via the remote schema
        eq_(
            inspect(connection).has_sequence(
                "other_sequence", schema=config.test_schema
            ),
            False,
        )

    @testing.requires.schemas
    def test_has_sequence_remote_not_in_default(self, connection):
        # remote-schema sequences are not visible via the default schema
        eq_(
            inspect(connection).has_sequence("schema_seq"),
            False,
        )

    def test_get_sequence_names(self, connection):
        exp = {"other_seq", "user_id_seq"}

        # use subset checks: the database may contain unrelated sequences
        res = set(inspect(connection).get_sequence_names())
        is_true(res.intersection(exp) == exp)
        is_true("schema_seq" not in res)

    @testing.requires.schemas
    def test_get_sequence_names_no_sequence_schema(self, connection):
        eq_(
            inspect(connection).get_sequence_names(
                schema=config.test_schema_2
            ),
            [],
        )

    @testing.requires.schemas
    def test_get_sequence_names_sequences_schema(self, connection):
        eq_(
            sorted(
                inspect(connection).get_sequence_names(
                    schema=config.test_schema
                )
            ),
            ["schema_seq", "user_id_seq"],
        )
+
+
class HasSequenceTestEmpty(fixtures.TestBase):
    """get_sequence_names() against a database with no sequences."""

    __requires__ = ("sequences",)
    __backend__ = True

    def test_get_sequence_names_no_sequence(self, connection):
        names = inspect(connection).get_sequence_names()
        eq_(names, [])
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
new file mode 100644
index 0000000..b96350e
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -0,0 +1,1508 @@
+# coding: utf-8
+
+import datetime
+import decimal
+import json
+import re
+
+from .. import config
+from .. import engines
+from .. import fixtures
+from .. import mock
+from ..assertions import eq_
+from ..assertions import is_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import and_
+from ... import BigInteger
+from ... import bindparam
+from ... import Boolean
+from ... import case
+from ... import cast
+from ... import Date
+from ... import DateTime
+from ... import Float
+from ... import Integer
+from ... import JSON
+from ... import literal
+from ... import MetaData
+from ... import null
+from ... import Numeric
+from ... import select
+from ... import String
+from ... import testing
+from ... import Text
+from ... import Time
+from ... import TIMESTAMP
+from ... import TypeDecorator
+from ... import Unicode
+from ... import UnicodeText
+from ... import util
+from ...orm import declarative_base
+from ...orm import Session
+from ...sql.sqltypes import LargeBinary
+from ...sql.sqltypes import PickleType
+from ...util import compat
+from ...util import u
+
+
class _LiteralRoundTripFixture(object):
    """Mixin providing the ``literal_round_trip`` fixture used by the type
    test suites below."""

    # subclasses may override (or redefine as a property) to disable the
    # WHERE comparison on backends that can't compare the type
    supports_whereclause = True

    @testing.fixture
    def literal_round_trip(self, metadata, connection):
        """test literal rendering"""

        # for literal, we test the literal render in an INSERT
        # into a typed column. we can then SELECT it back as its
        # official type; ideally we'd be able to use CAST here
        # but MySQL in particular can't CAST fully

        def run(type_, input_, output, filter_=None):
            t = Table("t", metadata, Column("x", type_))
            t.create(connection)

            for value in input_:
                ins = (
                    t.insert()
                    .values(x=literal(value, type_))
                    .compile(
                        dialect=testing.db.dialect,
                        compile_kwargs=dict(literal_binds=True),
                    )
                )
                connection.execute(ins)

                if self.supports_whereclause:
                    stmt = t.select().where(t.c.x == literal(value))
                else:
                    stmt = t.select()

                stmt = stmt.compile(
                    dialect=testing.db.dialect,
                    compile_kwargs=dict(literal_binds=True),
                )
                for row in connection.execute(stmt):
                    # NOTE: rebinding ``value`` here shadows the loop
                    # variable for the remainder of this iteration
                    value = row[0]
                    if filter_ is not None:
                        value = filter_(value)
                    assert value in output

        return run
+
+
class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase):
    """Shared round-trip tests for Unicode string types; subclasses supply
    ``datatype`` (Unicode or UnicodeText)."""

    __requires__ = ("unicode_data",)

    # sample text mixing accented characters and non-BMP emoji
    data = u(
        "Alors vous imaginez ma 🐍 surprise, au lever du jour, "
        "quand une drôle de petite 🐍 voix m’a réveillé. Elle "
        "disait: « S’il vous plaît… dessine-moi 🐍 un mouton! »"
    )

    @property
    def supports_whereclause(self):
        return config.requirements.expressions_against_unbounded_text.enabled

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "unicode_table",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("unicode_data", cls.datatype),
        )

    def test_round_trip(self, connection):
        unicode_table = self.tables.unicode_table

        connection.execute(
            unicode_table.insert(), {"id": 1, "unicode_data": self.data}
        )

        row = connection.execute(select(unicode_table.c.unicode_data)).first()

        eq_(row, (self.data,))
        assert isinstance(row[0], util.text_type)

    def test_round_trip_executemany(self, connection):
        unicode_table = self.tables.unicode_table

        connection.execute(
            unicode_table.insert(),
            [{"id": i, "unicode_data": self.data} for i in range(1, 4)],
        )

        rows = connection.execute(
            select(unicode_table.c.unicode_data)
        ).fetchall()
        eq_(rows, [(self.data,) for i in range(1, 4)])
        for row in rows:
            assert isinstance(row[0], util.text_type)

    def _test_null_strings(self, connection):
        # helper invoked by subclasses' test_null_strings_* methods
        unicode_table = self.tables.unicode_table

        connection.execute(
            unicode_table.insert(), {"id": 1, "unicode_data": None}
        )
        row = connection.execute(select(unicode_table.c.unicode_data)).first()
        eq_(row, (None,))

    def _test_empty_strings(self, connection):
        # helper invoked by subclasses' test_empty_strings_* methods
        unicode_table = self.tables.unicode_table

        connection.execute(
            unicode_table.insert(), {"id": 1, "unicode_data": u("")}
        )
        row = connection.execute(select(unicode_table.c.unicode_data)).first()
        eq_(row, (u(""),))

    def test_literal(self, literal_round_trip):
        literal_round_trip(self.datatype, [self.data], [self.data])

    def test_literal_non_ascii(self, literal_round_trip):
        literal_round_trip(
            self.datatype, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
        )
+
+
class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
    """Unicode round trips against a length-bound VARCHAR column."""

    __requires__ = ("unicode_data",)
    __backend__ = True

    datatype = Unicode(255)

    @requirements.empty_strings_varchar
    def test_empty_strings_varchar(self, connection):
        self._test_empty_strings(connection)

    def test_null_strings_varchar(self, connection):
        self._test_null_strings(connection)
+
+
class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
    """Unicode round trips against an unbounded TEXT column."""

    __requires__ = "unicode_data", "text_type"
    __backend__ = True

    datatype = UnicodeText()

    @requirements.empty_strings_text
    def test_empty_strings_text(self, connection):
        self._test_empty_strings(connection)

    def test_null_strings_text(self, connection):
        self._test_null_strings(connection)
+
+
class BinaryTest(_LiteralRoundTripFixture, fixtures.TablesTest):
    """Round trips for LargeBinary and PickleType columns."""

    __requires__ = ("binary_literals",)
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "binary_table",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("binary_data", LargeBinary),
            Column("pickle_data", PickleType),
        )

    def test_binary_roundtrip(self, connection):
        binary_table = self.tables.binary_table

        connection.execute(
            binary_table.insert(), {"id": 1, "binary_data": b"this is binary"}
        )
        row = connection.execute(select(binary_table.c.binary_data)).first()
        eq_(row, (b"this is binary",))

    def test_pickle_roundtrip(self, connection):
        # PickleType serializes/deserializes arbitrary Python structures
        binary_table = self.tables.binary_table

        connection.execute(
            binary_table.insert(),
            {"id": 1, "pickle_data": {"foo": [1, 2, 3], "bar": "bat"}},
        )
        row = connection.execute(select(binary_table.c.pickle_data)).first()
        eq_(row, ({"foo": [1, 2, 3], "bar": "bat"},))
+
+
class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest):
    """Round trips and literal rendering for the Text type, including
    quoting-sensitive payloads."""

    __requires__ = ("text_type",)
    __backend__ = True

    @property
    def supports_whereclause(self):
        return config.requirements.expressions_against_unbounded_text.enabled

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "text_table",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("text_data", Text),
        )

    def test_text_roundtrip(self, connection):
        text_table = self.tables.text_table

        connection.execute(
            text_table.insert(), {"id": 1, "text_data": "some text"}
        )
        row = connection.execute(select(text_table.c.text_data)).first()
        eq_(row, ("some text",))

    @testing.requires.empty_strings_text
    def test_text_empty_strings(self, connection):
        text_table = self.tables.text_table

        connection.execute(text_table.insert(), {"id": 1, "text_data": ""})
        row = connection.execute(select(text_table.c.text_data)).first()
        eq_(row, ("",))

    def test_text_null_strings(self, connection):
        text_table = self.tables.text_table

        connection.execute(text_table.insert(), {"id": 1, "text_data": None})
        row = connection.execute(select(text_table.c.text_data)).first()
        eq_(row, (None,))

    def test_literal(self, literal_round_trip):
        literal_round_trip(Text, ["some text"], ["some text"])

    def test_literal_non_ascii(self, literal_round_trip):
        literal_round_trip(
            Text, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
        )

    def test_literal_quoting(self, literal_round_trip):
        # mixed single/double quotes exercise the dialect's string escaping
        data = """some 'text' hey "hi there" that's text"""
        literal_round_trip(Text, [data], [data])

    def test_literal_backslashes(self, literal_round_trip):
        data = r"backslash one \ backslash two \\ end"
        literal_round_trip(Text, [data], [data])

    def test_literal_percentsigns(self, literal_round_trip):
        # percent signs must survive paramstyle-related doubling
        data = r"percent % signs %% percent"
        literal_round_trip(Text, [data], [data])
+
+
class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
    """Literal rendering and DDL behavior for the String type."""

    __backend__ = True

    @requirements.unbounded_varchar
    def test_nolength_string(self):
        """A length-less String must be usable in CREATE TABLE where the
        backend allows unbounded VARCHAR."""
        metadata = MetaData()
        foo = Table("foo", metadata, Column("one", String))

        foo.create(config.db)
        foo.drop(config.db)

    def test_literal(self, literal_round_trip):
        # note that in Python 3, this invokes the Unicode
        # datatype for the literal part because all strings are unicode
        literal_round_trip(String(40), ["some text"], ["some text"])

    def test_literal_non_ascii(self, literal_round_trip):
        literal_round_trip(
            String(40), [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
        )

    def test_literal_quoting(self, literal_round_trip):
        data = """some 'text' hey "hi there" that's text"""
        literal_round_trip(String(40), [data], [data])

    def test_literal_backslashes(self, literal_round_trip):
        data = r"backslash one \ backslash two \\ end"
        literal_round_trip(String(40), [data], [data])
+
+
class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
    """Shared round-trip tests for date/time types; subclasses supply
    ``datatype`` and ``data``, plus optionally ``compare`` when the value
    read back differs from the value inserted (e.g. datetime -> date)."""

    compare = None

    @classmethod
    def define_tables(cls, metadata):
        # a TypeDecorator wrapping the same type, to verify decorated
        # round trips as well
        class Decorated(TypeDecorator):
            impl = cls.datatype
            cache_ok = True

        Table(
            "date_table",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("date_data", cls.datatype),
            Column("decorated_date_data", Decorated),
        )

    @testing.requires.datetime_implicit_bound
    def test_select_direct(self, connection):
        result = connection.scalar(select(literal(self.data)))
        eq_(result, self.data)

    def test_round_trip(self, connection):
        date_table = self.tables.date_table

        connection.execute(
            date_table.insert(), {"id": 1, "date_data": self.data}
        )

        row = connection.execute(select(date_table.c.date_data)).first()

        compare = self.compare or self.data
        eq_(row, (compare,))
        assert isinstance(row[0], type(compare))

    def test_round_trip_decorated(self, connection):
        date_table = self.tables.date_table

        connection.execute(
            date_table.insert(), {"id": 1, "decorated_date_data": self.data}
        )

        row = connection.execute(
            select(date_table.c.decorated_date_data)
        ).first()

        compare = self.compare or self.data
        eq_(row, (compare,))
        assert isinstance(row[0], type(compare))

    def test_null(self, connection):
        date_table = self.tables.date_table

        connection.execute(date_table.insert(), {"id": 1, "date_data": None})

        row = connection.execute(select(date_table.c.date_data)).first()
        eq_(row, (None,))

    @testing.requires.datetime_literals
    def test_literal(self, literal_round_trip):
        compare = self.compare or self.data
        literal_round_trip(self.datatype, [self.data], [compare])

    @testing.requires.standalone_null_binds_whereclause
    def test_null_bound_comparison(self):
        # this test is based on an Oracle issue observed in #4886.
        # passing NULL for an expression that needs to be interpreted as
        # a certain type, does the DBAPI have the info it needs to do this.
        date_table = self.tables.date_table
        with config.db.begin() as conn:
            result = conn.execute(
                date_table.insert(), {"id": 1, "date_data": self.data}
            )
            id_ = result.inserted_primary_key[0]
            stmt = select(date_table.c.id).where(
                case(
                    (
                        bindparam("foo", type_=self.datatype) != None,
                        bindparam("foo", type_=self.datatype),
                    ),
                    else_=date_table.c.date_data,
                )
                == date_table.c.date_data
            )

            row = conn.execute(stmt, {"foo": None}).first()
            eq_(row[0], id_)
+
+
class DateTimeTest(_DateFixture, fixtures.TablesTest):
    """Round trips for DATETIME values (whole-second precision)."""

    __requires__ = ("datetime",)
    __backend__ = True
    datatype = DateTime
    data = datetime.datetime(2012, 10, 15, 12, 57, 18)
+
+
class DateTimeTZTest(_DateFixture, fixtures.TablesTest):
    """Round trips for timezone-aware DATETIME values."""

    __requires__ = ("datetime_timezone",)
    __backend__ = True
    datatype = DateTime(timezone=True)
    data = datetime.datetime(
        2012, 10, 15, 12, 57, 18, tzinfo=compat.timezone.utc
    )
+
+
class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
    """Round trips for DATETIME values with microsecond precision."""

    __requires__ = ("datetime_microseconds",)
    __backend__ = True
    datatype = DateTime
    data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+
+
class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest):
    """TIMESTAMP round trip where microseconds must be preserved."""

    __requires__ = ("timestamp_microseconds",)
    __backend__ = True
    datatype = TIMESTAMP
    data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)

    @testing.requires.timestamp_microseconds_implicit_bound
    def test_select_direct(self, connection):
        # the value is bound with no explicit SQL type; requires the
        # backend to implicitly bind the datetime as TIMESTAMP
        result = connection.scalar(select(literal(self.data)))
        eq_(result, self.data)
+
+
class TimeTest(_DateFixture, fixtures.TablesTest):
    """Time round trip, naive value at seconds resolution."""

    __requires__ = ("time",)
    __backend__ = True
    datatype = Time
    data = datetime.time(12, 57, 18)
+
+
class TimeTZTest(_DateFixture, fixtures.TablesTest):
    """Time(timezone=True) round trip with a UTC-aware value."""

    __requires__ = ("time_timezone",)
    __backend__ = True
    datatype = Time(timezone=True)
    data = datetime.time(12, 57, 18, tzinfo=compat.timezone.utc)
+
+
class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
    """Time round trip where microseconds must be preserved."""

    __requires__ = ("time_microseconds",)
    __backend__ = True
    datatype = Time
    data = datetime.time(12, 57, 18, 396)
+
+
class DateTest(_DateFixture, fixtures.TablesTest):
    """Date round trip."""

    __requires__ = ("date",)
    __backend__ = True
    datatype = Date
    data = datetime.date(2012, 10, 15)
+
+
class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
    """A datetime bound into a Date column comes back as a date;
    ``compare`` holds the truncated expected value."""

    __requires__ = "date", "date_coerces_from_datetime"
    __backend__ = True
    datatype = Date
    data = datetime.datetime(2012, 10, 15, 12, 57, 18)
    compare = datetime.date(2012, 10, 15)
+
+
class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest):
    """DateTime round trip for a pre-1900 value."""

    __requires__ = ("datetime_historic",)
    __backend__ = True
    datatype = DateTime
    data = datetime.datetime(1850, 11, 10, 11, 52, 35)
+
+
class DateHistoricTest(_DateFixture, fixtures.TablesTest):
    """Date round trip for a pre-1900 value."""

    __requires__ = ("date_historic",)
    __backend__ = True
    datatype = Date
    data = datetime.date(1727, 4, 1)
+
+
class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase):
    """Integer / BigInteger round trips, including literal rendering."""

    __backend__ = True

    def test_literal(self, literal_round_trip):
        literal_round_trip(Integer, [5], [5])

    def test_huge_int(self, integer_round_trip):
        # value exceeds 32-bit range, exercising BigInteger storage
        integer_round_trip(BigInteger, 1376537018368127)

    @testing.fixture
    def integer_round_trip(self, metadata, connection):
        # fixture returning a runner that creates a one-row table with
        # ``datatype`` and asserts the value comes back as a Python int
        def run(datatype, data):
            int_table = Table(
                "integer_table",
                metadata,
                Column(
                    "id",
                    Integer,
                    primary_key=True,
                    test_needs_autoincrement=True,
                ),
                Column("integer_data", datatype),
            )

            metadata.create_all(config.db)

            connection.execute(
                int_table.insert(), {"id": 1, "integer_data": data}
            )

            row = connection.execute(select(int_table.c.integer_data)).first()

            eq_(row, (data,))

            # on Python 2 a large value may come back as ``long``
            if util.py3k:
                assert isinstance(row[0], int)
            else:
                assert isinstance(row[0], (long, int))  # noqa

        return run
+
+
class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase):
    """A TypeDecorator that stores ints as strings but CASTs in both
    directions, so Python-side values stay ints in results and in
    WHERE comparisons."""

    __backend__ = True

    @testing.fixture
    def string_as_int(self):
        class StringAsInt(TypeDecorator):
            # stored as VARCHAR(50) at the DDL level
            impl = String(50)
            cache_ok = True

            def get_dbapi_type(self, dbapi):
                return dbapi.NUMBER

            def column_expression(self, col):
                # CAST fetched column back to Integer
                return cast(col, Integer)

            def bind_expression(self, col):
                # CAST bound ints to String on the way in
                return cast(col, String(50))

        return StringAsInt()

    def test_special_type(self, metadata, connection, string_as_int):

        type_ = string_as_int

        t = Table("t", metadata, Column("x", type_))
        t.create(connection)

        connection.execute(t.insert(), [{"x": x} for x in [1, 2, 3]])

        result = {row[0] for row in connection.execute(t.select())}
        eq_(result, {1, 2, 3})

        # comparison in WHERE also goes through the CASTs
        result = {
            row[0] for row in connection.execute(t.select().where(t.c.x == 2))
        }
        eq_(result, {2})
+
+
class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
    """Round trips for Numeric / Float covering Decimal vs. float
    return types, precision/scale retention, E-notation and special
    float values; most cases are gated on backend requirements."""

    __backend__ = True

    @testing.fixture
    def do_numeric_test(self, metadata, connection):
        # fixture returning a runner: create a one-column table of
        # ``type_``, insert ``input_``, compare fetched set to ``output``
        @testing.emits_warning(
            r".*does \*not\* support Decimal objects natively"
        )
        def run(type_, input_, output, filter_=None, check_scale=False):
            t = Table("t", metadata, Column("x", type_))
            t.create(connection)
            connection.execute(t.insert(), [{"x": x} for x in input_])

            result = {row[0] for row in connection.execute(t.select())}
            output = set(output)
            if filter_:
                result = set(filter_(x) for x in result)
                output = set(filter_(x) for x in output)
            eq_(result, output)
            if check_scale:
                # compare string forms so trailing zeros (scale) are
                # significant, e.g. Decimal("1.000") vs Decimal("1").
                # NOTE(review): assumes the two equal sets iterate in
                # the same order -- holds in practice; confirm if flaky.
                eq_([str(x) for x in result], [str(x) for x in output])

        return run

    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
    def test_render_literal_numeric(self, literal_round_trip):
        literal_round_trip(
            Numeric(precision=8, scale=4),
            [15.7563, decimal.Decimal("15.7563")],
            [decimal.Decimal("15.7563")],
        )

    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
    def test_render_literal_numeric_asfloat(self, literal_round_trip):
        literal_round_trip(
            Numeric(precision=8, scale=4, asdecimal=False),
            [15.7563, decimal.Decimal("15.7563")],
            [15.7563],
        )

    def test_render_literal_float(self, literal_round_trip):
        # round to 5 places to absorb backend float representation noise
        literal_round_trip(
            Float(4),
            [15.7563, decimal.Decimal("15.7563")],
            [15.7563],
            filter_=lambda n: n is not None and round(n, 5) or None,
        )

    @testing.requires.precision_generic_float_type
    def test_float_custom_scale(self, do_numeric_test):
        do_numeric_test(
            Float(None, decimal_return_scale=7, asdecimal=True),
            [15.7563827, decimal.Decimal("15.7563827")],
            [decimal.Decimal("15.7563827")],
            check_scale=True,
        )

    def test_numeric_as_decimal(self, do_numeric_test):
        do_numeric_test(
            Numeric(precision=8, scale=4),
            [15.7563, decimal.Decimal("15.7563")],
            [decimal.Decimal("15.7563")],
        )

    def test_numeric_as_float(self, do_numeric_test):
        do_numeric_test(
            Numeric(precision=8, scale=4, asdecimal=False),
            [15.7563, decimal.Decimal("15.7563")],
            [15.7563],
        )

    @testing.requires.infinity_floats
    def test_infinity_floats(self, do_numeric_test):
        """test for #977, #7283"""

        do_numeric_test(
            Float(None),
            [float("inf")],
            [float("inf")],
        )

    @testing.requires.fetch_null_from_numeric
    def test_numeric_null_as_decimal(self, do_numeric_test):
        do_numeric_test(Numeric(precision=8, scale=4), [None], [None])

    @testing.requires.fetch_null_from_numeric
    def test_numeric_null_as_float(self, do_numeric_test):
        do_numeric_test(
            Numeric(precision=8, scale=4, asdecimal=False), [None], [None]
        )

    @testing.requires.floats_to_four_decimals
    def test_float_as_decimal(self, do_numeric_test):
        do_numeric_test(
            Float(precision=8, asdecimal=True),
            [15.7563, decimal.Decimal("15.7563"), None],
            [decimal.Decimal("15.7563"), None],
            filter_=lambda n: n is not None and round(n, 4) or None,
        )

    def test_float_as_float(self, do_numeric_test):
        do_numeric_test(
            Float(precision=8),
            [15.7563, decimal.Decimal("15.7563")],
            [15.7563],
            filter_=lambda n: n is not None and round(n, 5) or None,
        )

    def test_float_coerce_round_trip(self, connection):
        expr = 15.7563

        val = connection.scalar(select(literal(expr)))
        eq_(val, expr)

    # this does not work in MySQL, see #4036, however we choose not
    # to render CAST unconditionally since this is kind of an edge case.

    @testing.requires.implicit_decimal_binds
    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
    def test_decimal_coerce_round_trip(self, connection):
        expr = decimal.Decimal("15.7563")

        val = connection.scalar(select(literal(expr)))
        eq_(val, expr)

    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
    def test_decimal_coerce_round_trip_w_cast(self, connection):
        expr = decimal.Decimal("15.7563")

        val = connection.scalar(select(cast(expr, Numeric(10, 4))))
        eq_(val, expr)

    @testing.requires.precision_numerics_general
    def test_precision_decimal(self, do_numeric_test):
        numbers = set(
            [
                decimal.Decimal("54.234246451650"),
                decimal.Decimal("0.004354"),
                decimal.Decimal("900.0"),
            ]
        )

        do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers)

    @testing.requires.precision_numerics_enotation_large
    def test_enotation_decimal(self, do_numeric_test):
        """test exceedingly small decimals.

        Decimal reports values with E notation when the exponent
        is greater than 6.

        """

        numbers = set(
            [
                decimal.Decimal("1E-2"),
                decimal.Decimal("1E-3"),
                decimal.Decimal("1E-4"),
                decimal.Decimal("1E-5"),
                decimal.Decimal("1E-6"),
                decimal.Decimal("1E-7"),
                decimal.Decimal("1E-8"),
                decimal.Decimal("0.01000005940696"),
                decimal.Decimal("0.00000005940696"),
                decimal.Decimal("0.00000000000696"),
                decimal.Decimal("0.70000000000696"),
                decimal.Decimal("696E-12"),
            ]
        )
        do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers)

    @testing.requires.precision_numerics_enotation_large
    def test_enotation_decimal_large(self, do_numeric_test):
        """test exceedingly large decimals."""

        numbers = set(
            [
                decimal.Decimal("4E+8"),
                decimal.Decimal("5748E+15"),
                decimal.Decimal("1.521E+15"),
                decimal.Decimal("00000000000000.1E+12"),
            ]
        )
        do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers)

    @testing.requires.precision_numerics_many_significant_digits
    def test_many_significant_digits(self, do_numeric_test):
        numbers = set(
            [
                decimal.Decimal("31943874831932418390.01"),
                decimal.Decimal("319438950232418390.273596"),
                decimal.Decimal("87673.594069654243"),
            ]
        )
        do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers)

    @testing.requires.precision_numerics_retains_significant_digits
    def test_numeric_no_decimal(self, do_numeric_test):
        # trailing zeros must be retained, hence check_scale
        numbers = set([decimal.Decimal("1.000")])
        do_numeric_test(
            Numeric(precision=5, scale=3), numbers, numbers, check_scale=True
        )
+
+
class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
    """Boolean round trips, NULL handling, and rendering a bare
    boolean column in a WHERE clause, with and without the CHECK
    constraint that Boolean normally emits."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "boolean_table",
            metadata,
            Column("id", Integer, primary_key=True, autoincrement=False),
            Column("value", Boolean),
            # no CHECK constraint on this column
            Column("unconstrained_value", Boolean(create_constraint=False)),
        )

    def test_render_literal_bool(self, literal_round_trip):
        literal_round_trip(Boolean(), [True, False], [True, False])

    def test_round_trip(self, connection):
        boolean_table = self.tables.boolean_table

        connection.execute(
            boolean_table.insert(),
            {"id": 1, "value": True, "unconstrained_value": False},
        )

        row = connection.execute(
            select(boolean_table.c.value, boolean_table.c.unconstrained_value)
        ).first()

        eq_(row, (True, False))
        # the fetched value must be a real Python bool, not 0/1
        assert isinstance(row[0], bool)

    @testing.requires.nullable_booleans
    def test_null(self, connection):
        boolean_table = self.tables.boolean_table

        connection.execute(
            boolean_table.insert(),
            {"id": 1, "value": None, "unconstrained_value": None},
        )

        row = connection.execute(
            select(boolean_table.c.value, boolean_table.c.unconstrained_value)
        ).first()

        eq_(row, (None, None))

    def test_whereclause(self):
        # testing "WHERE <column>" renders a compatible expression
        boolean_table = self.tables.boolean_table

        with config.db.begin() as conn:
            conn.execute(
                boolean_table.insert(),
                [
                    {"id": 1, "value": True, "unconstrained_value": True},
                    {"id": 2, "value": False, "unconstrained_value": False},
                ],
            )

            eq_(
                conn.scalar(
                    select(boolean_table.c.id).where(boolean_table.c.value)
                ),
                1,
            )
            eq_(
                conn.scalar(
                    select(boolean_table.c.id).where(
                        boolean_table.c.unconstrained_value
                    )
                ),
                1,
            )
            eq_(
                conn.scalar(
                    select(boolean_table.c.id).where(~boolean_table.c.value)
                ),
                2,
            )
            eq_(
                conn.scalar(
                    select(boolean_table.c.id).where(
                        ~boolean_table.c.unconstrained_value
                    )
                ),
                2,
            )
+
+
class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
    """Round trips and typed index/path access for the JSON type,
    including NULL-vs-JSON-"null" semantics and custom serializers."""

    __requires__ = ("json_type",)
    __backend__ = True

    datatype = JSON

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "data_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("name", String(30), nullable=False),
            Column("data", cls.datatype, nullable=False),
            # none_as_null: Python None is stored as SQL NULL here
            Column("nulldata", cls.datatype(none_as_null=True)),
        )

    def test_round_trip_data1(self, connection):
        self._test_round_trip({"key1": "value1", "key2": "value2"}, connection)

    def _test_round_trip(self, data_element, connection):
        # insert one row with ``data_element`` and assert it comes back
        # structurally equal
        data_table = self.tables.data_table

        connection.execute(
            data_table.insert(),
            {"id": 1, "name": "row1", "data": data_element},
        )

        row = connection.execute(select(data_table.c.data)).first()

        eq_(row, (data_element,))

    def _index_fixtures(include_comparison):
        # decorator factory, invoked at class-definition time (hence no
        # ``self``); parameterizes a test with (datatype, value[, req])
        # combinations for typed JSON index access

        if include_comparison:
            # basically SQL Server and MariaDB can kind of do json
            # comparison, MySQL, PG and SQLite can't. not worth it.
            json_elements = []
        else:
            json_elements = [
                ("json", {"foo": "bar"}),
                ("json", ["one", "two", "three"]),
                (None, {"foo": "bar"}),
                (None, ["one", "two", "three"]),
            ]

        elements = [
            ("boolean", True),
            ("boolean", False),
            ("boolean", None),
            ("string", "some string"),
            ("string", None),
            ("string", util.u("réve illé")),
            (
                "string",
                util.u("réve🐍 illé"),
                testing.requires.json_index_supplementary_unicode_element,
            ),
            ("integer", 15),
            ("integer", 1),
            ("integer", 0),
            ("integer", None),
            ("float", 28.5),
            ("float", None),
            (
                "float",
                1234567.89,
            ),
            ("numeric", 1234567.89),
            # this one "works" because the float value you see here is
            # lost immediately to floating point stuff
            ("numeric", 99998969694839.983485848, requirements.python3),
            ("numeric", 99939.983485848, requirements.python3),
            ("_decimal", decimal.Decimal("1234567.89")),
            (
                "_decimal",
                decimal.Decimal("99998969694839.983485848"),
                # fails on SQLite and MySQL (non-mariadb)
                requirements.cast_precision_numerics_many_significant_digits,
            ),
            (
                "_decimal",
                decimal.Decimal("99939.983485848"),
            ),
        ] + json_elements

        def decorate(fn):
            fn = testing.combinations(id_="sa", *elements)(fn)

            return fn

        return decorate

    def _json_value_insert(self, connection, datatype, value, data_element):
        # insert ``data_element`` and return (datatype, compare_value,
        # precision/scale tuple) for the subsequent typed comparison
        data_table = self.tables.data_table
        if datatype == "_decimal":

            # Python's builtin json serializer basically doesn't support
            # Decimal objects without implicit float conversion period.
            # users can otherwise use simplejson which supports
            # precision decimals

            # https://bugs.python.org/issue16535

            # inserting as strings to avoid a new fixture around the
            # dialect which would have idiosyncrasies for different
            # backends.

            class DecimalEncoder(json.JSONEncoder):
                def default(self, o):
                    if isinstance(o, decimal.Decimal):
                        return str(o)
                    return super(DecimalEncoder, self).default(o)

            json_data = json.dumps(data_element, cls=DecimalEncoder)

            # take the quotes out.  yup, there is *literally* no other
            # way to get Python's json.dumps() to put all the digits in
            # the string
            json_data = re.sub(r'"(%s)"' % str(value), str(value), json_data)

            datatype = "numeric"

            connection.execute(
                data_table.insert().values(
                    name="row1",
                    # to pass the string directly to every backend, including
                    # PostgreSQL which needs the value to be CAST as JSON
                    # both in the SQL as well as at the prepared statement
                    # level for asyncpg, while at the same time MySQL
                    # doesn't even support CAST for JSON, here we are
                    # sending the string embedded in the SQL without using
                    # a parameter.
                    data=bindparam(None, json_data, literal_execute=True),
                    nulldata=bindparam(None, json_data, literal_execute=True),
                ),
            )
        else:
            connection.execute(
                data_table.insert(),
                {
                    "name": "row1",
                    "data": data_element,
                    "nulldata": data_element,
                },
            )

        p_s = None

        if datatype:
            if datatype == "numeric":
                # derive precision/scale from the literal digits of value
                a, b = str(value).split(".")
                s = len(b)
                p = len(a) + s

                if isinstance(value, decimal.Decimal):
                    compare_value = value
                else:
                    compare_value = decimal.Decimal(str(value))

                p_s = (p, s)
            else:
                compare_value = value
        else:
            compare_value = value

        return datatype, compare_value, p_s

    @_index_fixtures(False)
    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
    def test_index_typed_access(self, datatype, value):
        data_table = self.tables.data_table
        data_element = {"key1": value}

        with config.db.begin() as conn:

            datatype, compare_value, p_s = self._json_value_insert(
                conn, datatype, value, data_element
            )

            expr = data_table.c.data["key1"]
            if datatype:
                if datatype == "numeric" and p_s:
                    expr = expr.as_numeric(*p_s)
                else:
                    expr = getattr(expr, "as_%s" % datatype)()

            roundtrip = conn.scalar(select(expr))
            eq_(roundtrip, compare_value)
            if util.py3k:  # skip py2k to avoid comparing unicode to str etc.
                is_(type(roundtrip), type(compare_value))

    @_index_fixtures(True)
    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
    def test_index_typed_comparison(self, datatype, value):
        data_table = self.tables.data_table
        data_element = {"key1": value}

        with config.db.begin() as conn:
            datatype, compare_value, p_s = self._json_value_insert(
                conn, datatype, value, data_element
            )

            expr = data_table.c.data["key1"]
            if datatype:
                if datatype == "numeric" and p_s:
                    expr = expr.as_numeric(*p_s)
                else:
                    expr = getattr(expr, "as_%s" % datatype)()

            row = conn.execute(
                select(expr).where(expr == compare_value)
            ).first()

            # make sure we get a row even if value is None
            eq_(row, (compare_value,))

    @_index_fixtures(True)
    @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
    def test_path_typed_comparison(self, datatype, value):
        data_table = self.tables.data_table
        data_element = {"key1": {"subkey1": value}}
        with config.db.begin() as conn:

            datatype, compare_value, p_s = self._json_value_insert(
                conn, datatype, value, data_element
            )

            # tuple index produces a JSON path expression
            expr = data_table.c.data[("key1", "subkey1")]

            if datatype:
                if datatype == "numeric" and p_s:
                    expr = expr.as_numeric(*p_s)
                else:
                    expr = getattr(expr, "as_%s" % datatype)()

            row = conn.execute(
                select(expr).where(expr == compare_value)
            ).first()

            # make sure we get a row even if value is None
            eq_(row, (compare_value,))

    @testing.combinations(
        (True,),
        (False,),
        (None,),
        (15,),
        (0,),
        (-1,),
        (-1.0,),
        (15.052,),
        ("a string",),
        (util.u("réve illé"),),
        (util.u("réve🐍 illé"),),
    )
    def test_single_element_round_trip(self, element):
        # scalar (non-dict, non-list) JSON values round-trip too
        data_table = self.tables.data_table
        data_element = element
        with config.db.begin() as conn:
            conn.execute(
                data_table.insert(),
                {
                    "name": "row1",
                    "data": data_element,
                    "nulldata": data_element,
                },
            )

            row = conn.execute(
                select(data_table.c.data, data_table.c.nulldata)
            ).first()

            eq_(row, (data_element, data_element))

    def test_round_trip_custom_json(self):
        # json_serializer / json_deserializer hooks are actually invoked
        data_table = self.tables.data_table
        data_element = {"key1": "data1"}

        js = mock.Mock(side_effect=json.dumps)
        jd = mock.Mock(side_effect=json.loads)
        engine = engines.testing_engine(
            options=dict(json_serializer=js, json_deserializer=jd)
        )

        # support sqlite :memory: database...
        data_table.create(engine, checkfirst=True)
        with engine.begin() as conn:
            conn.execute(
                data_table.insert(), {"name": "row1", "data": data_element}
            )
            row = conn.execute(select(data_table.c.data)).first()

        eq_(row, (data_element,))
        eq_(js.mock_calls, [mock.call(data_element)])
        eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])

    @testing.combinations(
        ("parameters",),
        ("multiparameters",),
        ("values",),
        ("omit",),
        argnames="insert_type",
    )
    def test_round_trip_none_as_sql_null(self, connection, insert_type):
        # with none_as_null=True, Python None becomes SQL NULL regardless
        # of how the INSERT parameters are delivered
        col = self.tables.data_table.c["nulldata"]

        conn = connection

        if insert_type == "parameters":
            stmt, params = self.tables.data_table.insert(), {
                "name": "r1",
                "nulldata": None,
                "data": None,
            }
        elif insert_type == "multiparameters":
            stmt, params = self.tables.data_table.insert(), [
                {"name": "r1", "nulldata": None, "data": None}
            ]
        elif insert_type == "values":
            stmt, params = (
                self.tables.data_table.insert().values(
                    name="r1",
                    nulldata=None,
                    data=None,
                ),
                {},
            )
        elif insert_type == "omit":
            stmt, params = (
                self.tables.data_table.insert(),
                {"name": "r1", "data": None},
            )

        else:
            assert False

        conn.execute(stmt, params)

        eq_(
            conn.scalar(
                select(self.tables.data_table.c.name).where(col.is_(null()))
            ),
            "r1",
        )

        eq_(conn.scalar(select(col)), None)

    def test_round_trip_json_null_as_json_null(self, connection):
        # JSON.NULL stores the JSON literal 'null', distinct from SQL NULL
        col = self.tables.data_table.c["data"]

        conn = connection
        conn.execute(
            self.tables.data_table.insert(),
            {"name": "r1", "data": JSON.NULL},
        )

        eq_(
            conn.scalar(
                select(self.tables.data_table.c.name).where(
                    cast(col, String) == "null"
                )
            ),
            "r1",
        )

        eq_(conn.scalar(select(col)), None)

    @testing.combinations(
        ("parameters",),
        ("multiparameters",),
        ("values",),
        argnames="insert_type",
    )
    def test_round_trip_none_as_json_null(self, connection, insert_type):
        # default JSON (none_as_null=False): Python None becomes the
        # JSON literal 'null'
        col = self.tables.data_table.c["data"]

        if insert_type == "parameters":
            stmt, params = self.tables.data_table.insert(), {
                "name": "r1",
                "data": None,
            }
        elif insert_type == "multiparameters":
            stmt, params = self.tables.data_table.insert(), [
                {"name": "r1", "data": None}
            ]
        elif insert_type == "values":
            stmt, params = (
                self.tables.data_table.insert().values(name="r1", data=None),
                {},
            )
        else:
            assert False

        conn = connection
        conn.execute(stmt, params)

        eq_(
            conn.scalar(
                select(self.tables.data_table.c.name).where(
                    cast(col, String) == "null"
                )
            ),
            "r1",
        )

        eq_(conn.scalar(select(col)), None)

    def test_unicode_round_trip(self):
        # note we include Unicode supplementary characters as well
        with config.db.begin() as conn:
            conn.execute(
                self.tables.data_table.insert(),
                {
                    "name": "r1",
                    "data": {
                        util.u("réve🐍 illé"): util.u("réve🐍 illé"),
                        "data": {"k1": util.u("drôl🐍e")},
                    },
                },
            )

            eq_(
                conn.scalar(select(self.tables.data_table.c.data)),
                {
                    util.u("réve🐍 illé"): util.u("réve🐍 illé"),
                    "data": {"k1": util.u("drôl🐍e")},
                },
            )

    def test_eval_none_flag_orm(self, connection):
        # ORM flush and bulk_insert_mappings must both honor
        # none_as_null on a per-column basis

        Base = declarative_base()

        class Data(Base):
            __table__ = self.tables.data_table

        with Session(connection) as s:
            d1 = Data(name="d1", data=None, nulldata=None)
            s.add(d1)
            s.commit()

            s.bulk_insert_mappings(
                Data, [{"name": "d2", "data": None, "nulldata": None}]
            )
            eq_(
                s.query(
                    cast(self.tables.data_table.c.data, String()),
                    cast(self.tables.data_table.c.nulldata, String),
                )
                .filter(self.tables.data_table.c.name == "d1")
                .first(),
                ("null", None),
            )
            eq_(
                s.query(
                    cast(self.tables.data_table.c.data, String()),
                    cast(self.tables.data_table.c.nulldata, String),
                )
                .filter(self.tables.data_table.c.name == "d2")
                .first(),
                ("null", None),
            )
+
+
class JSONLegacyStringCastIndexTest(
    _LiteralRoundTripFixture, fixtures.TablesTest
):
    """test JSON index access with "cast to string", which we have documented
    for a long time as how to compare JSON values, but is ultimately not
    reliable in all cases.   The "as_XYZ()" comparators should be used
    instead.

    """

    __requires__ = ("json_type", "legacy_unconditional_json_extract")
    __backend__ = True

    datatype = JSON

    # sample documents: flat dict, keys with spaces/quotes, nested
    # arrays, a top-level array, deep nesting, mixed value types
    data1 = {"key1": "value1", "key2": "value2"}

    data2 = {
        "Key 'One'": "value1",
        "key two": "value2",
        "key three": "value ' three '",
    }

    data3 = {
        "key1": [1, 2, 3],
        "key2": ["one", "two", "three"],
        "key3": [{"four": "five"}, {"six": "seven"}],
    }

    data4 = ["one", "two", "three"]

    data5 = {
        "nested": {
            "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}],
            "elem2": {"elem3": {"elem4": "elem5"}},
        }
    }

    data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}}

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "data_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("name", String(30), nullable=False),
            Column("data", cls.datatype),
            Column("nulldata", cls.datatype(none_as_null=True)),
        )

    def _criteria_fixture(self):
        # load rows r1..r6 holding data1..data6
        with config.db.begin() as conn:
            conn.execute(
                self.tables.data_table.insert(),
                [
                    {"name": "r1", "data": self.data1},
                    {"name": "r2", "data": self.data2},
                    {"name": "r3", "data": self.data3},
                    {"name": "r4", "data": self.data4},
                    {"name": "r5", "data": self.data5},
                    {"name": "r6", "data": self.data6},
                ],
            )

    def _test_index_criteria(self, crit, expected, test_literal=True):
        # run ``crit`` both as a bound statement and (optionally) as
        # fully literal-rendered SQL; both must select ``expected``
        self._criteria_fixture()
        with config.db.connect() as conn:
            stmt = select(self.tables.data_table.c.name).where(crit)

            eq_(conn.scalar(stmt), expected)

            if test_literal:
                literal_sql = str(
                    stmt.compile(
                        config.db, compile_kwargs={"literal_binds": True}
                    )
                )

                eq_(conn.exec_driver_sql(literal_sql).scalar(), expected)

    def test_string_cast_crit_spaces_in_key(self):
        name = self.tables.data_table.c.name
        col = self.tables.data_table.c["data"]

        # limit the rows here to avoid PG error
        # "cannot extract field from a non-object", which is
        # fixed in 9.4 but may exist in 9.3
        self._test_index_criteria(
            and_(
                name.in_(["r1", "r2", "r3"]),
                cast(col["key two"], String) == '"value2"',
            ),
            "r2",
        )

    @config.requirements.json_array_indexes
    def test_string_cast_crit_simple_int(self):
        name = self.tables.data_table.c.name
        col = self.tables.data_table.c["data"]

        # limit the rows here to avoid PG error
        # "cannot extract array element from a non-array", which is
        # fixed in 9.4 but may exist in 9.3
        self._test_index_criteria(
            and_(
                name == "r4",
                cast(col[1], String) == '"two"',
            ),
            "r4",
        )

    def test_string_cast_crit_mixed_path(self):
        # path mixes string keys and an integer array index
        col = self.tables.data_table.c["data"]
        self._test_index_criteria(
            cast(col[("key3", 1, "six")], String) == '"seven"',
            "r3",
        )

    def test_string_cast_crit_string_path(self):
        col = self.tables.data_table.c["data"]
        self._test_index_criteria(
            cast(col[("nested", "elem2", "elem3", "elem4")], String)
            == '"elem5"',
            "r5",
        )

    def test_string_cast_crit_against_string_basic(self):
        name = self.tables.data_table.c.name
        col = self.tables.data_table.c["data"]

        self._test_index_criteria(
            and_(
                name == "r6",
                cast(col["b"], String) == '"some value"',
            ),
            "r6",
        )
+
+
__all__ = (
    # test classes exported to per-dialect test suites via
    # ``from sqlalchemy.testing.suite import *``; several names
    # (e.g. BinaryTest, TextTest, StringTest) are defined earlier
    # in this module
    "BinaryTest",
    "UnicodeVarcharTest",
    "UnicodeTextTest",
    "JSONTest",
    "JSONLegacyStringCastIndexTest",
    "DateTest",
    "DateTimeTest",
    "DateTimeTZTest",
    "TextTest",
    "NumericTest",
    "IntegerTest",
    "CastTypeDecoratorTest",
    "DateTimeHistoricTest",
    "DateTimeCoercedToDateTimeTest",
    "TimeMicrosecondsTest",
    "TimestampMicrosecondsTest",
    "TimeTest",
    "TimeTZTest",
    "DateTimeMicrosecondsTest",
    "DateHistoricTest",
    "StringTest",
    "BooleanTest",
)
diff --git a/lib/sqlalchemy/testing/suite/test_unicode_ddl.py b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py
new file mode 100644
index 0000000..a4ae334
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py
@@ -0,0 +1,206 @@
+# coding: utf-8
+"""verrrrry basic unicode column name testing"""
+
+from sqlalchemy import desc
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import testing
+from sqlalchemy import util
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import Table
+from sqlalchemy.util import u
+from sqlalchemy.util import ue
+
+
class UnicodeSchemaTest(fixtures.TablesTest):
    """Round trips, column targeting and reflection against tables and
    columns with non-ASCII names.

    NOTE(review): the tables are stashed in module globals ``t1``,
    ``t2``, ``t3`` rather than on ``cls`` -- a legacy pattern kept
    as-is here; the test methods rely on those globals.
    """

    __requires__ = ("unicode_ddl",)
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        global t1, t2, t3

        t1 = Table(
            u("unitable1"),
            metadata,
            Column(u("méil"), Integer, primary_key=True),
            Column(ue("\u6e2c\u8a66"), Integer),
            test_needs_fk=True,
        )
        t2 = Table(
            u("Unitéble2"),
            metadata,
            # ``key`` gives the columns ASCII Python-side names a/b
            Column(u("méil"), Integer, primary_key=True, key="a"),
            Column(
                ue("\u6e2c\u8a66"),
                Integer,
                ForeignKey(u("unitable1.méil")),
                key="b",
            ),
            test_needs_fk=True,
        )

        # Few DBs support Unicode foreign keys
        if testing.against("sqlite"):
            t3 = Table(
                ue("\u6e2c\u8a66"),
                metadata,
                Column(
                    ue("\u6e2c\u8a66_id"),
                    Integer,
                    primary_key=True,
                    autoincrement=False,
                ),
                Column(
                    ue("unitable1_\u6e2c\u8a66"),
                    Integer,
                    ForeignKey(ue("unitable1.\u6e2c\u8a66")),
                ),
                Column(
                    u("Unitéble2_b"), Integer, ForeignKey(u("Unitéble2.b"))
                ),
                # self-referential FK
                Column(
                    ue("\u6e2c\u8a66_self"),
                    Integer,
                    ForeignKey(ue("\u6e2c\u8a66.\u6e2c\u8a66_id")),
                ),
                test_needs_fk=True,
            )
        else:
            # same shape, no FK constraints
            t3 = Table(
                ue("\u6e2c\u8a66"),
                metadata,
                Column(
                    ue("\u6e2c\u8a66_id"),
                    Integer,
                    primary_key=True,
                    autoincrement=False,
                ),
                Column(ue("unitable1_\u6e2c\u8a66"), Integer),
                Column(u("Unitéble2_b"), Integer),
                Column(ue("\u6e2c\u8a66_self"), Integer),
                test_needs_fk=True,
            )

    def test_insert(self, connection):
        connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
        connection.execute(t2.insert(), {u("a"): 1, u("b"): 1})
        connection.execute(
            t3.insert(),
            {
                ue("\u6e2c\u8a66_id"): 1,
                ue("unitable1_\u6e2c\u8a66"): 5,
                u("Unitéble2_b"): 1,
                ue("\u6e2c\u8a66_self"): 1,
            },
        )

        eq_(connection.execute(t1.select()).fetchall(), [(1, 5)])
        eq_(connection.execute(t2.select()).fetchall(), [(1, 1)])
        eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)])

    def test_col_targeting(self, connection):
        # rows can be addressed via the Column objects themselves
        connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
        connection.execute(t2.insert(), {u("a"): 1, u("b"): 1})
        connection.execute(
            t3.insert(),
            {
                ue("\u6e2c\u8a66_id"): 1,
                ue("unitable1_\u6e2c\u8a66"): 5,
                u("Unitéble2_b"): 1,
                ue("\u6e2c\u8a66_self"): 1,
            },
        )

        row = connection.execute(t1.select()).first()
        eq_(row._mapping[t1.c[u("méil")]], 1)
        eq_(row._mapping[t1.c[ue("\u6e2c\u8a66")]], 5)

        row = connection.execute(t2.select()).first()
        eq_(row._mapping[t2.c[u("a")]], 1)
        eq_(row._mapping[t2.c[u("b")]], 1)

        row = connection.execute(t3.select()).first()
        eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_id")]], 1)
        eq_(row._mapping[t3.c[ue("unitable1_\u6e2c\u8a66")]], 5)
        eq_(row._mapping[t3.c[u("Unitéble2_b")]], 1)
        eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_self")]], 1)

    def test_reflect(self, connection):
        # reflected copies of the tables accept inserts/selects with
        # the same unicode names
        connection.execute(t1.insert(), {u("méil"): 2, ue("\u6e2c\u8a66"): 7})
        connection.execute(t2.insert(), {u("a"): 2, u("b"): 2})
        connection.execute(
            t3.insert(),
            {
                ue("\u6e2c\u8a66_id"): 2,
                ue("unitable1_\u6e2c\u8a66"): 7,
                u("Unitéble2_b"): 2,
                ue("\u6e2c\u8a66_self"): 2,
            },
        )

        meta = MetaData()
        tt1 = Table(t1.name, meta, autoload_with=connection)
        tt2 = Table(t2.name, meta, autoload_with=connection)
        tt3 = Table(t3.name, meta, autoload_with=connection)

        connection.execute(tt1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
        connection.execute(tt2.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 1})
        connection.execute(
            tt3.insert(),
            {
                ue("\u6e2c\u8a66_id"): 1,
                ue("unitable1_\u6e2c\u8a66"): 5,
                u("Unitéble2_b"): 1,
                ue("\u6e2c\u8a66_self"): 1,
            },
        )

        eq_(
            connection.execute(
                tt1.select().order_by(desc(u("méil")))
            ).fetchall(),
            [(2, 7), (1, 5)],
        )
        eq_(
            connection.execute(
                tt2.select().order_by(desc(u("méil")))
            ).fetchall(),
            [(2, 2), (1, 1)],
        )
        eq_(
            connection.execute(
                tt3.select().order_by(desc(ue("\u6e2c\u8a66_id")))
            ).fetchall(),
            [(2, 7, 2, 2), (1, 5, 1, 1)],
        )

    def test_repr(self):
        # repr() output differs between py2 (escaped) and py3 (literal)
        meta = MetaData()
        t = Table(
            ue("\u6e2c\u8a66"), meta, Column(ue("\u6e2c\u8a66_id"), Integer)
        )

        if util.py2k:
            eq_(
                repr(t),
                (
                    "Table('\\u6e2c\\u8a66', MetaData(), "
                    "Column('\\u6e2c\\u8a66_id', Integer(), "
                    "table=<\u6e2c\u8a66>), "
                    "schema=None)"
                ),
            )
        else:
            eq_(
                repr(t),
                (
                    "Table('測試', MetaData(), "
                    "Column('測試_id', Integer(), "
                    "table=<測試>), "
                    "schema=None)"
                ),
            )
diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py
new file mode 100644
index 0000000..f04a9d5
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_update_delete.py
@@ -0,0 +1,60 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import String
+
+
class SimpleUpdateDeleteTest(fixtures.TablesTest):
    """Backend suite tests for basic UPDATE and DELETE statements,
    asserting rowcount and result-object flags on dialects that report a
    sane rowcount."""

    run_deletes = "each"
    __requires__ = ("sane_rowcount",)
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        # a single table with a plain integer primary key
        Table(
            "plain_pk",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )

    @classmethod
    def insert_data(cls, connection):
        connection.execute(
            cls.tables.plain_pk.insert(),
            [
                {"id": 1, "data": "d1"},
                {"id": 2, "data": "d2"},
                {"id": 3, "data": "d3"},
            ],
        )

    def test_update(self, connection):
        table = self.tables.plain_pk
        result = connection.execute(
            table.update().where(table.c.id == 2), {"data": "d2_new"}
        )
        # an UPDATE is neither an INSERT nor row-returning, and exactly
        # one row matched
        assert not result.is_insert
        assert not result.returns_rows
        assert result.rowcount == 1

        eq_(
            connection.execute(
                table.select().order_by(table.c.id)
            ).fetchall(),
            [(1, "d1"), (2, "d2_new"), (3, "d3")],
        )

    def test_delete(self, connection):
        table = self.tables.plain_pk
        result = connection.execute(table.delete().where(table.c.id == 2))
        # a DELETE is neither an INSERT nor row-returning, and exactly
        # one row matched
        assert not result.is_insert
        assert not result.returns_rows
        assert result.rowcount == 1
        eq_(
            connection.execute(
                table.select().order_by(table.c.id)
            ).fetchall(),
            [(1, "d1"), (3, "d3")],
        )
+
+
+__all__ = ("SimpleUpdateDeleteTest",)
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
new file mode 100644
index 0000000..be89bc6
--- /dev/null
+++ b/lib/sqlalchemy/testing/util.py
@@ -0,0 +1,458 @@
+# testing/util.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import decimal
+import gc
+import random
+import sys
+import types
+
+from . import config
+from . import mock
+from .. import inspect
+from ..engine import Connection
+from ..schema import Column
+from ..schema import DropConstraint
+from ..schema import DropTable
+from ..schema import ForeignKeyConstraint
+from ..schema import MetaData
+from ..schema import Table
+from ..sql import schema
+from ..sql.sqltypes import Integer
+from ..util import decorator
+from ..util import defaultdict
+from ..util import has_refcount_gc
+from ..util import inspect_getfullargspec
+from ..util import py2k
+
+
if not has_refcount_gc:

    def non_refcount_gc_collect(*args):
        # run collect twice -- a non-refcounting GC may need more than
        # one pass before everything is actually reclaimed
        gc.collect()
        gc.collect()

    # no eager/lazy distinction without refcounting: both names force a
    # full collection
    gc_collect = lazy_gc = non_refcount_gc_collect
else:
    # assume CPython - straight gc.collect, lazy_gc() is a pass
    gc_collect = gc.collect

    def lazy_gc():
        pass
+
+
def picklers():
    """Yield ``(loads, dumps)`` callable pairs for every available pickle
    module and every supported pickle protocol.

    Each ``dumps`` is a one-argument callable with the module and
    protocol pre-bound.
    """
    modules = set()
    if py2k:
        try:
            import cPickle

            modules.add(cPickle)
        except ImportError:
            pass

    import pickle

    modules.add(pickle)

    # yes, this thing needs this much testing
    for pickle_ in modules:
        # range() excludes its stop value, so the upper bound must be
        # HIGHEST_PROTOCOL + 1 or the newest protocol is never tested
        for protocol in range(-2, pickle.HIGHEST_PROTOCOL + 1):
            # bind pickle_/protocol as defaults: a plain closure is
            # late-bound, so collecting the generator into a list first
            # (e.g. list(picklers())) would leave every dumps lambda
            # pointing at the final loop values
            yield pickle_.loads, (
                lambda d, pickle_=pickle_, protocol=protocol: pickle_.dumps(
                    d, protocol
                )
            )
+
+
if py2k:

    def random_choices(population, k=1):
        """py2k fallback for random.choices; samples *without*
        replacement, which is close enough for the test-suite callers."""
        pop = list(population)
        # lame but works :)
        random.shuffle(pop)
        return pop[0:k]


else:

    def random_choices(population, k=1):
        # py3: random.choices samples with replacement
        return random.choices(population, k=k)
+
+
def round_decimal(value, prec):
    """Round ``value`` to ``prec`` digits.

    Floats go through the builtin ``round``; Decimals are scaled up,
    floored to an integral value, then scaled back down.
    """
    if isinstance(value, float):
        return round(value, prec)

    shift = decimal.Decimal("1" + "0" * prec)
    floored = (value * shift).to_integral(decimal.ROUND_FLOOR)
    return floored / pow(10, prec)
+
+
class RandomSet(set):
    """A set whose iteration order and pop() choice are randomized.

    Useful for flushing out code that accidentally relies on set
    ordering; the set-algebra methods preserve the RandomSet type.
    """

    def __iter__(self):
        items = list(set.__iter__(self))
        random.shuffle(items)
        return iter(items)

    def pop(self):
        # pick a random member rather than an arbitrary-but-stable one
        members = list(set.__iter__(self))
        chosen = members[random.randrange(len(self))]
        self.remove(chosen)
        return chosen

    def union(self, other):
        return RandomSet(set.union(self, other))

    def difference(self, other):
        return RandomSet(set.difference(self, other))

    def intersection(self, other):
        return RandomSet(set.intersection(self, other))

    def copy(self):
        return RandomSet(self)
+
+
def conforms_partial_ordering(tuples, sorted_elements):
    """True if the given sorting conforms to the given partial ordering.

    ``tuples`` is a sequence of ``(parent, child)`` pairs; a conforming
    ordering never places a parent after one of its children.
    """
    children_of = defaultdict(set)
    for parent, child in tuples:
        children_of[parent].add(child)

    for idx, node in enumerate(sorted_elements):
        # node must not be a declared child of itself or of anything
        # that comes later in the ordering
        if any(node in children_of[later] for later in sorted_elements[idx:]):
            return False
    return True
+
+
def all_partial_orderings(tuples, elements):
    """Generate every ordering of ``elements`` that is consistent with
    the ``(parent, child)`` dependency ``tuples`` (parents first)."""
    predecessors = defaultdict(set)
    for parent, child in tuples:
        predecessors[child].add(parent)

    def _generate(remaining):
        # base case: a single element has exactly one ordering
        if len(remaining) == 1:
            yield list(remaining)
            return
        for candidate in remaining:
            rest = set(remaining).difference([candidate])
            # candidate may lead only if none of its parents remain
            if not rest.intersection(predecessors[candidate]):
                for tail in _generate(rest):
                    yield [candidate] + tail

    return iter(_generate(elements))
+
+
def function_named(fn, name):
    """Return a function carrying the given ``__name__``.

    Assigns in place when the Python implementation allows it, otherwise
    rebuilds an equivalent function object under the new name.

    This function should be phased out as much as possible
    in favor of @decorator. Tests that "generate" many named tests
    should be modernized.
    """
    try:
        fn.__name__ = name
        return fn
    except TypeError:
        # implementation forbids renaming; construct a fresh function
        # from the same code, globals, defaults and closure
        return types.FunctionType(
            fn.__code__,
            fn.__globals__,
            name,
            fn.__defaults__,
            fn.__closure__,
        )
+
+
def run_as_contextmanager(ctx, fn, *arg, **kw):
    """Invoke ``fn(entered_ctx, *arg, **kw)`` inside ``ctx``, emulating
    a ``with`` block by hand.

    This is not necessary anymore as we have placed 2.6
    as minimum Python version, however some tests are still using
    this structure.
    """
    entered = ctx.__enter__()
    try:
        value = fn(entered, *arg, **kw)
        # success-path __exit__ stays inside the try so that an error it
        # raises is also routed through the handler below
        ctx.__exit__(None, None, None)
    except BaseException:
        exc_info = sys.exc_info()
        suppressed = ctx.__exit__(*exc_info)
        if not suppressed:
            raise
        # context manager swallowed the exception; propagate its
        # (truthy) suppression value like the original did
        return suppressed
    else:
        return value
+
+
def rowset(results):
    """Converts the results of sql execution into a plain set of column tuples.

    Useful for asserting the results of an unordered query.
    """
    return set(tuple(row) for row in results)
+
+
def fail(msg):
    """Unconditionally fail the current test with ``msg``.

    Raises AssertionError directly instead of using an ``assert``
    statement, so the failure still fires under ``python -O`` (where
    asserts are stripped).
    """
    raise AssertionError(msg)
+
+
@decorator
def provide_metadata(fn, *args, **kw):
    """Provide bound MetaData for a single test, dropping afterwards.

    Legacy; use the "metadata" pytest fixture.

    Installs a fresh ``MetaData`` as ``self.metadata`` on the test
    instance (``args[0]``) for the duration of the decorated call, then
    drops all tables created on it and restores the previous attribute
    value.
    """

    from . import fixtures

    metadata = schema.MetaData()
    self = args[0]
    # remember any metadata attribute the test instance already had so
    # it can be restored afterwards
    prev_meta = getattr(self, "metadata", None)
    self.metadata = metadata
    try:
        return fn(*args, **kw)
    finally:
        # close out some things that get in the way of dropping tables.
        # when using the "metadata" fixture, there is a set ordering
        # of things that makes sure things are cleaned up in order, however
        # the simple "decorator" nature of this legacy function means
        # we have to hardcode some of that cleanup ahead of time.

        # close ORM sessions
        fixtures._close_all_sessions()

        # integrate with the "connection" fixture as there are many
        # tests where it is used along with provide_metadata
        if fixtures._connection_fixture_connection:
            # TODO: this warning can be used to find all the places
            # this is used with connection fixture
            # warn("mixing legacy provide metadata with connection fixture")
            drop_all_tables_from_metadata(
                metadata, fixtures._connection_fixture_connection
            )
            # as the provide_metadata fixture is often used with "testing.db",
            # when we do the drop we have to commit the transaction so that
            # the DB is actually updated as the CREATE would have been
            # committed
            fixtures._connection_fixture_connection.get_transaction().commit()
        else:
            drop_all_tables_from_metadata(metadata, config.db)
        self.metadata = prev_meta
+
+
def flag_combinations(*combinations):
    """A facade around @testing.combinations() oriented towards boolean
    keyword-based arguments.

    Each dict in ``combinations`` names the flags that are enabled for
    that case; the union of all keys (sorted) becomes the argument
    names, and each case's identifier is built from its True flags.

    E.g.::

        @testing.flag_combinations(
            dict(lazy=False, passive=False),
            dict(lazy=True, passive=False),
            dict(lazy=False, passive=True),
            dict(lazy=False, passive=True, raiseload=True),
        )


    would result in::

        @testing.combinations(
            ('', False, False, False),
            ('lazy', True, False, False),
            ('lazy_passive', True, True, False),
            ('lazy_passive', True, True, True),
            id_='iaaa',
            argnames='lazy,passive,raiseload'
        )

    """
    all_keys = set()
    for combo in combinations:
        all_keys.update(combo)
    sorted_keys = sorted(all_keys)

    cases = []
    for combo in combinations:
        # identifier from the enabled flag names, then one boolean per
        # known key (missing keys default to False)
        ident = "_".join(k for k in sorted_keys if combo.get(k, False))
        cases.append(
            (ident,) + tuple(combo.get(k, False) for k in sorted_keys)
        )

    return config.combinations(
        *cases,
        id_="i" + ("a" * len(sorted_keys)),
        argnames=",".join(sorted_keys)
    )
+
+
def lambda_combinations(lambda_arg_sets, **kw):
    """Generate a combinations fixture from a lambda that returns a
    sequence of argument sets.

    ``lambda_arg_sets`` is first called with Mock placeholders purely to
    count how many argument sets it produces; each combination then
    becomes a fixture function that re-invokes the lambda and selects
    its positional result.
    """
    args = inspect_getfullargspec(lambda_arg_sets)

    # call once with mocks only to learn how many arg sets exist
    arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]])

    def create_fixture(pos):
        def fixture(**kw):
            return lambda_arg_sets(**kw)[pos]

        fixture.__name__ = "fixture_%3.3d" % pos
        return fixture

    return config.combinations(
        *[(create_fixture(i),) for i in range(len(arg_sets))], **kw
    )
+
+
def resolve_lambda(__fn, **kw):
    """Given a no-arg lambda and a namespace, return a new lambda that
    has all the values filled in.

    This is used so that we can have module-level fixtures that
    refer to instance-level variables using lambdas.

    """
    positional = inspect_getfullargspec(__fn)[0]
    # keywords matching declared parameters are passed positionally;
    # everything left over is injected as globals of the rebuilt function
    call_kw = {name: kw.pop(name) for name in positional}
    namespace = dict(__fn.__globals__)
    namespace.update(kw)
    rebuilt = types.FunctionType(__fn.__code__, namespace)
    return rebuilt(**call_kw)
+
+
def metadata_fixture(ddl="function"):
    """Provide MetaData for a pytest fixture.

    ``ddl`` is the pytest fixture scope.  Tables defined by the
    decorated function are created before the yield and dropped on
    teardown.
    """

    def decorate(fn):
        def run_ddl(self):

            # fresh MetaData, also installed on the test instance
            metadata = self.metadata = schema.MetaData()
            try:
                result = fn(self, metadata)
                metadata.create_all(config.db)
                # TODO:
                # somehow get a per-function dml erase fixture here
                yield result
            finally:
                metadata.drop_all(config.db)

        return config.fixture(scope=ddl)(run_ddl)

    return decorate
+
+
def force_drop_names(*names):
    """Force the given table names to be dropped after test complete,
    isolating for foreign key cycles

    """

    @decorator
    def go(fn, *args, **kw):

        try:
            return fn(*args, **kw)
        finally:
            # drop only the named tables, whether the test passed or not
            drop_all_tables(config.db, inspect(config.db), include_names=names)

    return go
+
+
class adict(dict):
    """Dict whose keys are also reachable as attributes.

    A key *shadows* any dict attribute of the same name; attribute
    lookup only falls back to the real attribute when the key is absent.
    """

    def __getattribute__(self, name):
        # try the mapping first so keys win over dict methods/attributes
        try:
            return self[name]
        except KeyError:
            return dict.__getattribute__(self, name)

    def __call__(self, *keys):
        """Return the values for ``keys`` as a tuple, in order."""
        return tuple(self[k] for k in keys)

    get_all = __call__
+
+
def drop_all_tables_from_metadata(metadata, engine_or_connection):
    """Drop all tables in ``metadata`` against an Engine or Connection.

    Lets the testing reaper prepare first; on dialects without ALTER
    support the "Can't sort tables" warning is tolerated rather than
    raised as an error.
    """
    from . import engines

    def go(connection):
        engines.testing_reaper.prepare_for_drop_tables(connection)

        if not connection.dialect.supports_alter:
            from . import assertions

            # FK cycles can't be broken constraint-first without ALTER;
            # accept the resulting sort warning as non-fatal
            with assertions.expect_warnings(
                "Can't sort tables", assert_=False
            ):
                metadata.drop_all(connection)
        else:
            metadata.drop_all(connection)

    if not isinstance(engine_or_connection, Connection):
        # given an Engine: run the drop within its own transaction
        with engine_or_connection.begin() as connection:
            go(connection)
    else:
        go(engine_or_connection)
+
+
def drop_all_tables(engine, inspector, schema=None, include_names=None):
    """Drop tables reported by ``inspector``, handling foreign key cycles.

    Iterates ``get_sorted_table_and_fkc_names`` in reverse so dependent
    tables are dropped first; entries without a table name carry the
    foreign key constraints participating in cycles, which are dropped
    explicitly (when the dialect supports ALTER) so the remaining tables
    can be dropped.

    ``include_names``, when given, restricts the operation to those
    table names.
    """

    if include_names is not None:
        include_names = set(include_names)

    with engine.begin() as conn:
        for tname, fkcs in reversed(
            inspector.get_sorted_table_and_fkc_names(schema=schema)
        ):
            if tname:
                if include_names is not None and tname not in include_names:
                    continue
                conn.execute(
                    DropTable(Table(tname, MetaData(), schema=schema))
                )
            elif fkcs:
                if not engine.dialect.supports_alter:
                    continue
                for tname, fkc in fkcs:
                    if (
                        include_names is not None
                        and tname not in include_names
                    ):
                        continue
                    # stub Table with dummy columns: only the table and
                    # constraint names matter for emitting the
                    # DROP CONSTRAINT statement
                    tb = Table(
                        tname,
                        MetaData(),
                        Column("x", Integer),
                        Column("y", Integer),
                        schema=schema,
                    )
                    conn.execute(
                        DropConstraint(
                            ForeignKeyConstraint([tb.c.x], [tb.c.y], name=fkc)
                        )
                    )
+
+
def teardown_events(event_cls):
    """Decorator that clears all listeners on ``event_cls`` after the
    decorated test completes, whether it passed or failed."""

    @decorator
    def decorate(fn, *arg, **kw):
        try:
            return fn(*arg, **kw)
        finally:
            event_cls._clear()

    return decorate
diff --git a/lib/sqlalchemy/testing/warnings.py b/lib/sqlalchemy/testing/warnings.py
new file mode 100644
index 0000000..3e78387
--- /dev/null
+++ b/lib/sqlalchemy/testing/warnings.py
@@ -0,0 +1,82 @@
+# testing/warnings.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+from __future__ import absolute_import
+
+import warnings
+
+from . import assertions
+from .. import exc as sa_exc
+from ..util.langhelpers import _warnings_warn
+
+
class SATestSuiteWarning(Warning):
    """warning for a condition detected during tests that is non-fatal

    Currently outside of SAWarning so that we can work around tools like
    Alembic doing the wrong thing with warnings.

    Registered as "always" display (never an error) by
    :func:`.setup_filters`.

    """
+
+
def warn_test_suite(message):
    """Emit a non-fatal :class:`.SATestSuiteWarning` with ``message``."""
    _warnings_warn(message, category=SATestSuiteWarning)
+
+
def setup_filters():
    """Set global warning behavior for the test suite."""

    # TODO: at this point we can use the normal pytest warnings plugin,
    # if we decide the test suite can be linked to pytest only

    # module pattern matching warnings raised from our own code
    origin = r"^(?:test|sqlalchemy)\..*"

    # SQLAlchemy's own warnings: pending deprecations are ignored,
    # deprecations and SAWarnings are promoted to errors
    warnings.filterwarnings(
        "ignore", category=sa_exc.SAPendingDeprecationWarning
    )
    warnings.filterwarnings("error", category=sa_exc.SADeprecationWarning)
    warnings.filterwarnings("error", category=sa_exc.SAWarning)

    # test-suite informational warnings always display
    warnings.filterwarnings("always", category=SATestSuiteWarning)

    # stdlib DeprecationWarning is an error only when raised from our
    # own modules
    warnings.filterwarnings(
        "error", category=DeprecationWarning, module=origin
    )

    # ignore things that are deprecated *as of* 2.0 :)
    warnings.filterwarnings(
        "ignore",
        category=sa_exc.SADeprecationWarning,
        message=r".*\(deprecated since: 2.0\)$",
    )
    warnings.filterwarnings(
        "ignore",
        category=sa_exc.SADeprecationWarning,
        message=r"^The (Sybase|firebird) dialect is deprecated and will be",
    )

    try:
        import pytest
    except ImportError:
        pass
    else:
        # show pytest's own deprecation notices once when they originate
        # from our code
        warnings.filterwarnings(
            "once", category=pytest.PytestDeprecationWarning, module=origin
        )
+
+
def assert_warnings(fn, warning_msgs, regex=False):
    """Assert that each of the given warnings are emitted by fn.

    Deprecated. Please use assertions.expect_warnings().

    ``warning_msgs`` is a list of expected SAWarning message fragments,
    treated as regular expressions when ``regex`` is True.  Returns the
    result of calling ``fn``.
    """

    with assertions._expect_warnings(
        sa_exc.SAWarning, warning_msgs, regex=regex
    ):
        return fn()