first add files
This commit is contained in:
86
lib/sqlalchemy/testing/__init__.py
Normal file
86
lib/sqlalchemy/testing/__init__.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# testing/__init__.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
|
||||
from . import config
|
||||
from . import mock
|
||||
from .assertions import assert_raises
|
||||
from .assertions import assert_raises_context_ok
|
||||
from .assertions import assert_raises_message
|
||||
from .assertions import assert_raises_message_context_ok
|
||||
from .assertions import assert_warns
|
||||
from .assertions import assert_warns_message
|
||||
from .assertions import AssertsCompiledSQL
|
||||
from .assertions import AssertsExecutionResults
|
||||
from .assertions import ComparesTables
|
||||
from .assertions import emits_warning
|
||||
from .assertions import emits_warning_on
|
||||
from .assertions import eq_
|
||||
from .assertions import eq_ignore_whitespace
|
||||
from .assertions import eq_regex
|
||||
from .assertions import expect_deprecated
|
||||
from .assertions import expect_deprecated_20
|
||||
from .assertions import expect_raises
|
||||
from .assertions import expect_raises_message
|
||||
from .assertions import expect_warnings
|
||||
from .assertions import in_
|
||||
from .assertions import is_
|
||||
from .assertions import is_false
|
||||
from .assertions import is_instance_of
|
||||
from .assertions import is_none
|
||||
from .assertions import is_not
|
||||
from .assertions import is_not_
|
||||
from .assertions import is_not_none
|
||||
from .assertions import is_true
|
||||
from .assertions import le_
|
||||
from .assertions import ne_
|
||||
from .assertions import not_in
|
||||
from .assertions import not_in_
|
||||
from .assertions import startswith_
|
||||
from .assertions import uses_deprecated
|
||||
from .config import async_test
|
||||
from .config import combinations
|
||||
from .config import combinations_list
|
||||
from .config import db
|
||||
from .config import fixture
|
||||
from .config import requirements as requires
|
||||
from .exclusions import _is_excluded
|
||||
from .exclusions import _server_version
|
||||
from .exclusions import against as _against
|
||||
from .exclusions import db_spec
|
||||
from .exclusions import exclude
|
||||
from .exclusions import fails
|
||||
from .exclusions import fails_if
|
||||
from .exclusions import fails_on
|
||||
from .exclusions import fails_on_everything_except
|
||||
from .exclusions import future
|
||||
from .exclusions import only_if
|
||||
from .exclusions import only_on
|
||||
from .exclusions import skip
|
||||
from .exclusions import skip_if
|
||||
from .schema import eq_clause_element
|
||||
from .schema import eq_type_affinity
|
||||
from .util import adict
|
||||
from .util import fail
|
||||
from .util import flag_combinations
|
||||
from .util import force_drop_names
|
||||
from .util import lambda_combinations
|
||||
from .util import metadata_fixture
|
||||
from .util import provide_metadata
|
||||
from .util import resolve_lambda
|
||||
from .util import rowset
|
||||
from .util import run_as_contextmanager
|
||||
from .util import teardown_events
|
||||
from .warnings import assert_warnings
|
||||
from .warnings import warn_test_suite
|
||||
|
||||
|
||||
def against(*queries):
    # convenience wrapper: evaluate exclusions.against() (imported as
    # _against) using the currently configured database
    return _against(config._current, *queries)
|
||||
|
||||
|
||||
crashes = skip
|
||||
845
lib/sqlalchemy/testing/assertions.py
Normal file
845
lib/sqlalchemy/testing/assertions.py
Normal file
@@ -0,0 +1,845 @@
|
||||
# testing/assertions.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from . import assertsql
|
||||
from . import config
|
||||
from . import engines
|
||||
from . import mock
|
||||
from .exclusions import db_spec
|
||||
from .util import fail
|
||||
from .. import exc as sa_exc
|
||||
from .. import schema
|
||||
from .. import sql
|
||||
from .. import types as sqltypes
|
||||
from .. import util
|
||||
from ..engine import default
|
||||
from ..engine import url
|
||||
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
|
||||
from ..util import compat
|
||||
from ..util import decorator
|
||||
|
||||
|
||||
def expect_warnings(*messages, **kw):
|
||||
"""Context manager which expects one or more warnings.
|
||||
|
||||
With no arguments, squelches all SAWarning and RemovedIn20Warning emitted via
|
||||
sqlalchemy.util.warn and sqlalchemy.util.warn_limited. Otherwise
|
||||
pass string expressions that will match selected warnings via regex;
|
||||
all non-matching warnings are sent through.
|
||||
|
||||
The expect version **asserts** that the warnings were in fact seen.
|
||||
|
||||
Note that the test suite sets SAWarning warnings to raise exceptions.
|
||||
|
||||
""" # noqa
|
||||
return _expect_warnings(
|
||||
(sa_exc.RemovedIn20Warning, sa_exc.SAWarning), messages, **kw
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def expect_warnings_on(db, *messages, **kw):
|
||||
"""Context manager which expects one or more warnings on specific
|
||||
dialects.
|
||||
|
||||
The expect version **asserts** that the warnings were in fact seen.
|
||||
|
||||
"""
|
||||
spec = db_spec(db)
|
||||
|
||||
if isinstance(db, util.string_types) and not spec(config._current):
|
||||
yield
|
||||
else:
|
||||
with expect_warnings(*messages, **kw):
|
||||
yield
|
||||
|
||||
|
||||
def emits_warning(*messages):
    """Decorator form of expect_warnings().

    Note that emits_warning does **not** assert that the warnings
    were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        # pass assert_ after *messages as a keyword argument, matching
        # the call style used by uses_deprecated()
        with expect_warnings(*messages, assert_=False):
            return fn(*args, **kw)

    return decorate
|
||||
|
||||
|
||||
def expect_deprecated(*messages, **kw):
    # context manager expecting SADeprecationWarning matching *messages*
    return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw)
|
||||
|
||||
|
||||
def expect_deprecated_20(*messages, **kw):
    # context manager expecting 2.0-style deprecation warnings
    return _expect_warnings(sa_exc.Base20DeprecationWarning, messages, **kw)
|
||||
|
||||
|
||||
def emits_warning_on(db, *messages):
    """Mark a test as emitting a warning on a specific dialect.

    With no arguments, squelches all SAWarning failures.  Or pass one or more
    strings; these will be matched to the root of the warning description by
    warnings.filterwarnings().

    Note that emits_warning_on does **not** assert that the warnings
    were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        # pass assert_ after *messages as a keyword argument, matching
        # the call style used by uses_deprecated()
        with expect_warnings_on(db, *messages, assert_=False):
            return fn(*args, **kw)

    return decorate
|
||||
|
||||
|
||||
def uses_deprecated(*messages):
|
||||
"""Mark a test as immune from fatal deprecation warnings.
|
||||
|
||||
With no arguments, squelches all SADeprecationWarning failures.
|
||||
Or pass one or more strings; these will be matched to the root
|
||||
of the warning description by warnings.filterwarnings().
|
||||
|
||||
As a special case, you may pass a function name prefixed with //
|
||||
and it will be re-written as needed to match the standard warning
|
||||
verbiage emitted by the sqlalchemy.util.deprecated decorator.
|
||||
|
||||
Note that uses_deprecated does **not** assert that the warnings
|
||||
were in fact seen.
|
||||
|
||||
"""
|
||||
|
||||
@decorator
|
||||
def decorate(fn, *args, **kw):
|
||||
with expect_deprecated(*messages, assert_=False):
|
||||
return fn(*args, **kw)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
_FILTERS = None
|
||||
_SEEN = None
|
||||
_EXC_CLS = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _expect_warnings(
    exc_cls,
    messages,
    regex=True,
    search_msg=False,
    assert_=True,
    py2konly=False,
    raise_on_any_unexpected=False,
    squelch_other_warnings=False,
):
    """Core implementation behind expect_warnings() / expect_deprecated().

    Patches ``warnings.warn`` so that warnings of class ``exc_cls`` whose
    text matches ``messages`` are intercepted; with ``assert_`` True, all
    given messages must be seen by the time the block exits.  Supports
    nesting: a nested call merges its filters into the outermost block's
    module-level state (_FILTERS / _SEEN / _EXC_CLS).
    """

    global _FILTERS, _SEEN, _EXC_CLS

    # compile messages to case-insensitive, dot-matches-newline patterns
    # unless plain string comparison was requested
    if regex or search_msg:
        filters = [re.compile(msg, re.I | re.S) for msg in messages]
    else:
        filters = list(messages)

    if _FILTERS is not None:
        # nested call; update _FILTERS and _SEEN, return. outer
        # block will assert our messages
        assert _SEEN is not None
        assert _EXC_CLS is not None
        _FILTERS.extend(filters)
        _SEEN.update(filters)
        _EXC_CLS += (exc_cls,)
        yield
    else:
        # outermost call: establish module-level state consumed by
        # any nested _expect_warnings blocks
        seen = _SEEN = set(filters)
        _FILTERS = filters
        _EXC_CLS = (exc_cls,)

        if raise_on_any_unexpected:

            def real_warn(msg, *arg, **kw):
                raise AssertionError("Got unexpected warning: %r" % msg)

        else:
            real_warn = warnings.warn

        def our_warn(msg, *arg, **kw):
            # warnings.warn() accepts either a Warning instance or a
            # (message, category) pair; normalize to (str, class)
            if isinstance(msg, _EXC_CLS):
                exception = type(msg)
                msg = str(msg)
            elif arg:
                exception = arg[0]
            else:
                exception = None

            if not exception or not issubclass(exception, _EXC_CLS):
                # not a warning class we are intercepting
                if not squelch_other_warnings:
                    return real_warn(msg, *arg, **kw)
                else:
                    return

            if not filters and not raise_on_any_unexpected:
                return

            # mark the first matching filter as "seen"; unmatched
            # warnings fall through to the real warn (or are squelched)
            for filter_ in filters:
                if (
                    (search_msg and filter_.search(msg))
                    or (regex and filter_.match(msg))
                    or (not regex and filter_ == msg)
                ):
                    seen.discard(filter_)
                    break
            else:
                if not squelch_other_warnings:
                    real_warn(msg, *arg, **kw)

        # also force 2.0 deprecation warnings on while the block runs
        with mock.patch("warnings.warn", our_warn), mock.patch(
            "sqlalchemy.util.SQLALCHEMY_WARN_20", True
        ), mock.patch(
            "sqlalchemy.util.deprecations.SQLALCHEMY_WARN_20", True
        ), mock.patch(
            "sqlalchemy.engine.row.LegacyRow._default_key_style", 2
        ):
            try:
                yield
            finally:
                # reset nesting state even on exception
                _SEEN = _FILTERS = _EXC_CLS = None

            if assert_ and (not py2konly or not compat.py3k):
                assert not seen, "Warnings were not seen: %s" % ", ".join(
                    "%r" % (s.pattern if regex else s) for s in seen
                )
|
||||
|
||||
|
||||
def global_cleanup_assertions():
|
||||
"""Check things that have to be finalized at the end of a test suite.
|
||||
|
||||
Hardcoded at the moment, a modular system can be built here
|
||||
to support things like PG prepared transactions, tables all
|
||||
dropped, etc.
|
||||
|
||||
"""
|
||||
_assert_no_stray_pool_connections()
|
||||
|
||||
|
||||
def _assert_no_stray_pool_connections():
    # the testing reaper tracks connections used by tests; fail if any
    # were left unclosed
    engines.testing_reaper.assert_all_closed()
|
||||
|
||||
|
||||
def eq_regex(a, b, msg=None):
    """Assert that regex ``b`` matches string ``a``, with repr messaging
    on failure."""
    failure = msg or "%r !~ %r" % (a, b)
    assert re.match(b, a), failure
|
||||
|
||||
|
||||
def eq_(a, b, msg=None):
    """Assert a == b, with repr messaging on failure."""
    failure = msg or "%r != %r" % (a, b)
    assert a == b, failure
|
||||
|
||||
|
||||
def ne_(a, b, msg=None):
    """Assert a != b, with repr messaging on failure."""
    failure = msg or "%r == %r" % (a, b)
    assert a != b, failure
|
||||
|
||||
|
||||
def le_(a, b, msg=None):
    """Assert a <= b, with repr messaging on failure.

    The default message previously read ``"%r != %r"``, which misstated
    the comparison being made; it now reads ``"%r > %r"``.
    """
    assert a <= b, msg or "%r > %r" % (a, b)
|
||||
|
||||
|
||||
def is_instance_of(a, b, msg=None):
    """Assert isinstance(a, b), with repr messaging on failure."""
    failure = msg or "%r is not an instance of %r" % (a, b)
    assert isinstance(a, b), failure
|
||||
|
||||
|
||||
def is_none(a, msg=None):
    """Assert ``a is None``, with repr messaging on failure."""
    is_(a, None, msg=msg)
|
||||
|
||||
|
||||
def is_not_none(a, msg=None):
    """Assert ``a is not None``, with repr messaging on failure."""
    is_not(a, None, msg=msg)
|
||||
|
||||
|
||||
def is_true(a, msg=None):
    """Assert ``bool(a) is True``, with repr messaging on failure."""
    is_(bool(a), True, msg=msg)
|
||||
|
||||
|
||||
def is_false(a, msg=None):
    """Assert ``bool(a) is False``, with repr messaging on failure."""
    is_(bool(a), False, msg=msg)
|
||||
|
||||
|
||||
def is_(a, b, msg=None):
    """Assert a is b, with repr messaging on failure."""
    failure = msg or "%r is not %r" % (a, b)
    assert a is b, failure
|
||||
|
||||
|
||||
def is_not(a, b, msg=None):
    """Assert a is not b, with repr messaging on failure."""
    failure = msg or "%r is %r" % (a, b)
    assert a is not b, failure
|
||||
|
||||
|
||||
# deprecated. See #5429
|
||||
is_not_ = is_not
|
||||
|
||||
|
||||
def in_(a, b, msg=None):
    """Assert a in b, with repr messaging on failure."""
    failure = msg or "%r not in %r" % (a, b)
    assert a in b, failure
|
||||
|
||||
|
||||
def not_in(a, b, msg=None):
    """Assert a not in b, with repr messaging on failure."""
    assert a not in b, msg or "%r is in %r" % (a, b)
|
||||
|
||||
|
||||
# deprecated. See #5429
|
||||
not_in_ = not_in
|
||||
|
||||
|
||||
def startswith_(a, fragment, msg=None):
    """Assert a.startswith(fragment), with repr messaging on failure."""
    failure = msg or "%r does not start with %r" % (a, fragment)
    assert a.startswith(fragment), failure
|
||||
|
||||
|
||||
def eq_ignore_whitespace(a, b, msg=None):
    """Assert a == b after normalizing whitespace: leading whitespace and
    newlines are removed, and runs of spaces are collapsed to one."""

    def _normalize(text):
        # strip leading whitespace / newlines, then squeeze space runs
        text = re.sub(r"^\s+?|\n", "", text)
        return re.sub(r" {2,}", " ", text)

    a = _normalize(a)
    b = _normalize(b)

    assert a == b, msg or "%r != %r" % (a, b)
|
||||
|
||||
|
||||
def _assert_proper_exception_context(exception):
    """assert that any exception we're catching does not have a __context__
    without a __cause__, and that __suppress_context__ is never set.

    Python 3 will report nested as exceptions as "during the handling of
    error X, error Y occurred". That's not what we want to do. we want
    these exceptions in a cause chain.

    """

    if not util.py3k:
        # implicit exception chaining does not exist on Python 2
        return

    # a __context__ differing from __cause__ means an implicit (not
    # "raise ... from ...") chain, which the test suite disallows
    if (
        exception.__context__ is not exception.__cause__
        and not exception.__suppress_context__
    ):
        assert False, (
            "Exception %r was correctly raised but did not set a cause, "
            "within context %r as its cause."
            % (exception, exception.__context__)
        )
|
||||
|
||||
|
||||
def assert_raises(except_cls, callable_, *args, **kw):
    # assert the callable raises except_cls with a proper __cause__
    # chain; returns the exception instance
    return _assert_raises(except_cls, callable_, args, kw, check_context=True)
|
||||
|
||||
|
||||
def assert_raises_context_ok(except_cls, callable_, *args, **kw):
    # like assert_raises() but without the exception-context check
    return _assert_raises(except_cls, callable_, args, kw)
|
||||
|
||||
|
||||
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
    # assert the callable raises except_cls with a message matching the
    # regex ``msg``; returns the exception instance
    return _assert_raises(
        except_cls, callable_, args, kwargs, msg=msg, check_context=True
    )
|
||||
|
||||
|
||||
def assert_warns(except_cls, callable_, *args, **kwargs):
|
||||
"""legacy adapter function for functions that were previously using
|
||||
assert_raises with SAWarning or similar.
|
||||
|
||||
has some workarounds to accommodate the fact that the callable completes
|
||||
with this approach rather than stopping at the exception raise.
|
||||
|
||||
|
||||
"""
|
||||
with _expect_warnings(except_cls, [".*"], squelch_other_warnings=True):
|
||||
return callable_(*args, **kwargs)
|
||||
|
||||
|
||||
def assert_warns_message(except_cls, msg, callable_, *args, **kwargs):
|
||||
"""legacy adapter function for functions that were previously using
|
||||
assert_raises with SAWarning or similar.
|
||||
|
||||
has some workarounds to accommodate the fact that the callable completes
|
||||
with this approach rather than stopping at the exception raise.
|
||||
|
||||
Also uses regex.search() to match the given message to the error string
|
||||
rather than regex.match().
|
||||
|
||||
"""
|
||||
with _expect_warnings(
|
||||
except_cls,
|
||||
[msg],
|
||||
search_msg=True,
|
||||
regex=False,
|
||||
squelch_other_warnings=True,
|
||||
):
|
||||
return callable_(*args, **kwargs)
|
||||
|
||||
|
||||
def assert_raises_message_context_ok(
    except_cls, msg, callable_, *args, **kwargs
):
    # like assert_raises_message() but without the exception-context check
    return _assert_raises(except_cls, callable_, args, kwargs, msg=msg)
|
||||
|
||||
|
||||
def _assert_raises(
    except_cls, callable_, args, kwargs, msg=None, check_context=False
):
    # invoke the callable inside the _expect_raises context manager and
    # hand the caught exception back to the caller for inspection
    with _expect_raises(except_cls, msg, check_context) as ec:
        callable_(*args, **kwargs)
    return ec.error
|
||||
|
||||
|
||||
class _ErrorContainer(object):
    # mutable holder through which _expect_raises exposes the caught
    # exception to the code using the ``with`` block
    error = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _expect_raises(except_cls, msg=None, check_context=False):
    """Context manager asserting that the block raises ``except_cls``.

    Yields an _ErrorContainer whose ``.error`` is set to the caught
    exception; optionally matches ``msg`` via regex search against the
    exception text and validates the exception's cause chain.
    """
    # warnings must go through expect_warnings(); catching them here
    # would silently succeed without the warning machinery
    if (
        isinstance(except_cls, type)
        and issubclass(except_cls, Warning)
        or isinstance(except_cls, Warning)
    ):
        raise TypeError(
            "Use expect_warnings for warnings, not "
            "expect_raises / assert_raises"
        )
    ec = _ErrorContainer()
    if check_context:
        # if we're already inside an except: block, implicit chaining
        # is expected and the context check must be skipped
        are_we_already_in_a_traceback = sys.exc_info()[0]
    try:
        yield ec
        success = False
    except except_cls as err:
        ec.error = err
        success = True
        if msg is not None:
            assert re.search(
                msg, util.text_type(err), re.UNICODE
            ), "%r !~ %s" % (msg, err)
        if check_context and not are_we_already_in_a_traceback:
            _assert_proper_exception_context(err)
        print(util.text_type(err).encode("utf-8"))

    # it's generally a good idea to not carry traceback objects outside
    # of the except: block, but in this case especially we seem to have
    # hit some bug in either python 3.10.0b2 or greenlet or both which
    # this seems to fix:
    # https://github.com/python-greenlet/greenlet/issues/242
    del ec

    # assert outside the block so it works for AssertionError too !
    assert success, "Callable did not raise an exception"
|
||||
|
||||
|
||||
def expect_raises(except_cls, check_context=True):
    # context-manager form of assert_raises()
    return _expect_raises(except_cls, check_context=check_context)
|
||||
|
||||
|
||||
def expect_raises_message(except_cls, msg, check_context=True):
    # context-manager form of assert_raises_message()
    return _expect_raises(except_cls, msg=msg, check_context=check_context)
|
||||
|
||||
|
||||
class AssertsCompiledSQL(object):
|
||||
def assert_compile(
|
||||
self,
|
||||
clause,
|
||||
result,
|
||||
params=None,
|
||||
checkparams=None,
|
||||
for_executemany=False,
|
||||
check_literal_execute=None,
|
||||
check_post_param=None,
|
||||
dialect=None,
|
||||
checkpositional=None,
|
||||
check_prefetch=None,
|
||||
use_default_dialect=False,
|
||||
allow_dialect_select=False,
|
||||
supports_default_values=True,
|
||||
supports_default_metavalue=True,
|
||||
literal_binds=False,
|
||||
render_postcompile=False,
|
||||
schema_translate_map=None,
|
||||
render_schema_translate=False,
|
||||
default_schema_name=None,
|
||||
from_linting=False,
|
||||
):
|
||||
if use_default_dialect:
|
||||
dialect = default.DefaultDialect()
|
||||
dialect.supports_default_values = supports_default_values
|
||||
dialect.supports_default_metavalue = supports_default_metavalue
|
||||
elif allow_dialect_select:
|
||||
dialect = None
|
||||
else:
|
||||
if dialect is None:
|
||||
dialect = getattr(self, "__dialect__", None)
|
||||
|
||||
if dialect is None:
|
||||
dialect = config.db.dialect
|
||||
elif dialect == "default":
|
||||
dialect = default.DefaultDialect()
|
||||
dialect.supports_default_values = supports_default_values
|
||||
dialect.supports_default_metavalue = supports_default_metavalue
|
||||
elif dialect == "default_enhanced":
|
||||
dialect = default.StrCompileDialect()
|
||||
elif isinstance(dialect, util.string_types):
|
||||
dialect = url.URL.create(dialect).get_dialect()()
|
||||
|
||||
if default_schema_name:
|
||||
dialect.default_schema_name = default_schema_name
|
||||
|
||||
kw = {}
|
||||
compile_kwargs = {}
|
||||
|
||||
if schema_translate_map:
|
||||
kw["schema_translate_map"] = schema_translate_map
|
||||
|
||||
if params is not None:
|
||||
kw["column_keys"] = list(params)
|
||||
|
||||
if literal_binds:
|
||||
compile_kwargs["literal_binds"] = True
|
||||
|
||||
if render_postcompile:
|
||||
compile_kwargs["render_postcompile"] = True
|
||||
|
||||
if for_executemany:
|
||||
kw["for_executemany"] = True
|
||||
|
||||
if render_schema_translate:
|
||||
kw["render_schema_translate"] = True
|
||||
|
||||
if from_linting or getattr(self, "assert_from_linting", False):
|
||||
kw["linting"] = sql.FROM_LINTING
|
||||
|
||||
from sqlalchemy import orm
|
||||
|
||||
if isinstance(clause, orm.Query):
|
||||
stmt = clause._statement_20()
|
||||
stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL
|
||||
clause = stmt
|
||||
|
||||
if compile_kwargs:
|
||||
kw["compile_kwargs"] = compile_kwargs
|
||||
|
||||
class DontAccess(object):
|
||||
def __getattribute__(self, key):
|
||||
raise NotImplementedError(
|
||||
"compiler accessed .statement; use "
|
||||
"compiler.current_executable"
|
||||
)
|
||||
|
||||
class CheckCompilerAccess(object):
|
||||
def __init__(self, test_statement):
|
||||
self.test_statement = test_statement
|
||||
self._annotations = {}
|
||||
self.supports_execution = getattr(
|
||||
test_statement, "supports_execution", False
|
||||
)
|
||||
|
||||
if self.supports_execution:
|
||||
self._execution_options = test_statement._execution_options
|
||||
|
||||
if hasattr(test_statement, "_returning"):
|
||||
self._returning = test_statement._returning
|
||||
if hasattr(test_statement, "_inline"):
|
||||
self._inline = test_statement._inline
|
||||
if hasattr(test_statement, "_return_defaults"):
|
||||
self._return_defaults = test_statement._return_defaults
|
||||
|
||||
def _default_dialect(self):
|
||||
return self.test_statement._default_dialect()
|
||||
|
||||
def compile(self, dialect, **kw):
|
||||
return self.test_statement.compile.__func__(
|
||||
self, dialect=dialect, **kw
|
||||
)
|
||||
|
||||
def _compiler(self, dialect, **kw):
|
||||
return self.test_statement._compiler.__func__(
|
||||
self, dialect, **kw
|
||||
)
|
||||
|
||||
def _compiler_dispatch(self, compiler, **kwargs):
|
||||
if hasattr(compiler, "statement"):
|
||||
with mock.patch.object(
|
||||
compiler, "statement", DontAccess()
|
||||
):
|
||||
return self.test_statement._compiler_dispatch(
|
||||
compiler, **kwargs
|
||||
)
|
||||
else:
|
||||
return self.test_statement._compiler_dispatch(
|
||||
compiler, **kwargs
|
||||
)
|
||||
|
||||
# no construct can assume it's the "top level" construct in all cases
|
||||
# as anything can be nested. ensure constructs don't assume they
|
||||
# are the "self.statement" element
|
||||
c = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)
|
||||
|
||||
if isinstance(clause, sqltypes.TypeEngine):
|
||||
cache_key_no_warnings = clause._static_cache_key
|
||||
if cache_key_no_warnings:
|
||||
hash(cache_key_no_warnings)
|
||||
else:
|
||||
cache_key_no_warnings = clause._generate_cache_key()
|
||||
if cache_key_no_warnings:
|
||||
hash(cache_key_no_warnings[0])
|
||||
|
||||
param_str = repr(getattr(c, "params", {}))
|
||||
if util.py3k:
|
||||
param_str = param_str.encode("utf-8").decode("ascii", "ignore")
|
||||
print(
|
||||
("\nSQL String:\n" + util.text_type(c) + param_str).encode(
|
||||
"utf-8"
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"\nSQL String:\n"
|
||||
+ util.text_type(c).encode("utf-8")
|
||||
+ param_str
|
||||
)
|
||||
|
||||
cc = re.sub(r"[\n\t]", "", util.text_type(c))
|
||||
|
||||
eq_(cc, result, "%r != %r on dialect %r" % (cc, result, dialect))
|
||||
|
||||
if checkparams is not None:
|
||||
eq_(c.construct_params(params), checkparams)
|
||||
if checkpositional is not None:
|
||||
p = c.construct_params(params)
|
||||
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
|
||||
if check_prefetch is not None:
|
||||
eq_(c.prefetch, check_prefetch)
|
||||
if check_literal_execute is not None:
|
||||
eq_(
|
||||
{
|
||||
c.bind_names[b]: b.effective_value
|
||||
for b in c.literal_execute_params
|
||||
},
|
||||
check_literal_execute,
|
||||
)
|
||||
if check_post_param is not None:
|
||||
eq_(
|
||||
{
|
||||
c.bind_names[b]: b.effective_value
|
||||
for b in c.post_compile_params
|
||||
},
|
||||
check_post_param,
|
||||
)
|
||||
|
||||
|
||||
class ComparesTables(object):
    """Mixin providing assertions that compare a Table definition against
    its reflected counterpart."""

    def assert_tables_equal(self, table, reflected_table, strict_types=False):
        """Assert column-by-column equivalence of ``table`` and
        ``reflected_table``: names, primary key / nullable flags, types
        (exact class when ``strict_types``, else type affinity), string
        lengths, foreign key targets and server defaults."""
        assert len(table.c) == len(reflected_table.c)
        for c, reflected_c in zip(table.c, reflected_table.c):
            eq_(c.name, reflected_c.name)
            assert reflected_c is reflected_table.c[c.name]
            eq_(c.primary_key, reflected_c.primary_key)
            eq_(c.nullable, reflected_c.nullable)

            if strict_types:
                msg = "Type '%s' doesn't correspond to type '%s'"
                assert isinstance(reflected_c.type, type(c.type)), msg % (
                    reflected_c.type,
                    c.type,
                )
            else:
                # affinity comparison only; reflection may return a
                # generic type for a dialect-specific one
                self.assert_types_base(reflected_c, c)

            if isinstance(c.type, sqltypes.String):
                eq_(c.type.length, reflected_c.type.length)

            eq_(
                {f.column.name for f in c.foreign_keys},
                {f.column.name for f in reflected_c.foreign_keys},
            )
            if c.server_default:
                # reflected defaults come back as generic FetchedValue
                assert isinstance(
                    reflected_c.server_default, schema.FetchedValue
                )

        assert len(table.primary_key) == len(reflected_table.primary_key)
        for c in table.primary_key:
            assert reflected_table.primary_key.columns[c.name] is not None

    def assert_types_base(self, c1, c2):
        """Assert the two columns' types share the same type affinity."""
        assert c1.type._compare_type_affinity(
            c2.type
        ), "On column %r, type '%s' doesn't correspond to type '%s'" % (
            c1.name,
            c1.type,
            c2.type,
        )
|
||||
|
||||
|
||||
class AssertsExecutionResults(object):
|
||||
def assert_result(self, result, class_, *objects):
|
||||
result = list(result)
|
||||
print(repr(result))
|
||||
self.assert_list(result, class_, objects)
|
||||
|
||||
def assert_list(self, result, class_, list_):
|
||||
self.assert_(
|
||||
len(result) == len(list_),
|
||||
"result list is not the same size as test list, "
|
||||
+ "for class "
|
||||
+ class_.__name__,
|
||||
)
|
||||
for i in range(0, len(list_)):
|
||||
self.assert_row(class_, result[i], list_[i])
|
||||
|
||||
def assert_row(self, class_, rowobj, desc):
|
||||
self.assert_(
|
||||
rowobj.__class__ is class_, "item class is not " + repr(class_)
|
||||
)
|
||||
for key, value in desc.items():
|
||||
if isinstance(value, tuple):
|
||||
if isinstance(value[1], list):
|
||||
self.assert_list(getattr(rowobj, key), value[0], value[1])
|
||||
else:
|
||||
self.assert_row(value[0], getattr(rowobj, key), value[1])
|
||||
else:
|
||||
self.assert_(
|
||||
getattr(rowobj, key) == value,
|
||||
"attribute %s value %s does not match %s"
|
||||
% (key, getattr(rowobj, key), value),
|
||||
)
|
||||
|
||||
def assert_unordered_result(self, result, cls, *expected):
|
||||
"""As assert_result, but the order of objects is not considered.
|
||||
|
||||
The algorithm is very expensive but not a big deal for the small
|
||||
numbers of rows that the test suite manipulates.
|
||||
"""
|
||||
|
||||
class immutabledict(dict):
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
found = util.IdentitySet(result)
|
||||
expected = {immutabledict(e) for e in expected}
|
||||
|
||||
for wrong in util.itertools_filterfalse(
|
||||
lambda o: isinstance(o, cls), found
|
||||
):
|
||||
fail(
|
||||
'Unexpected type "%s", expected "%s"'
|
||||
% (type(wrong).__name__, cls.__name__)
|
||||
)
|
||||
|
||||
if len(found) != len(expected):
|
||||
fail(
|
||||
'Unexpected object count "%s", expected "%s"'
|
||||
% (len(found), len(expected))
|
||||
)
|
||||
|
||||
NOVALUE = object()
|
||||
|
||||
def _compare_item(obj, spec):
|
||||
for key, value in spec.items():
|
||||
if isinstance(value, tuple):
|
||||
try:
|
||||
self.assert_unordered_result(
|
||||
getattr(obj, key), value[0], *value[1]
|
||||
)
|
||||
except AssertionError:
|
||||
return False
|
||||
else:
|
||||
if getattr(obj, key, NOVALUE) != value:
|
||||
return False
|
||||
return True
|
||||
|
||||
for expected_item in expected:
|
||||
for found_item in found:
|
||||
if _compare_item(found_item, expected_item):
|
||||
found.remove(found_item)
|
||||
break
|
||||
else:
|
||||
fail(
|
||||
"Expected %s instance with attributes %s not found."
|
||||
% (cls.__name__, repr(expected_item))
|
||||
)
|
||||
return True
|
||||
|
||||
def sql_execution_asserter(self, db=None):
|
||||
if db is None:
|
||||
from . import db as db
|
||||
|
||||
return assertsql.assert_engine(db)
|
||||
|
||||
def assert_sql_execution(self, db, callable_, *rules):
|
||||
with self.sql_execution_asserter(db) as asserter:
|
||||
result = callable_()
|
||||
asserter.assert_(*rules)
|
||||
return result
|
||||
|
||||
def assert_sql(self, db, callable_, rules):
|
||||
|
||||
newrules = []
|
||||
for rule in rules:
|
||||
if isinstance(rule, dict):
|
||||
newrule = assertsql.AllOf(
|
||||
*[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
|
||||
)
|
||||
else:
|
||||
newrule = assertsql.CompiledSQL(*rule)
|
||||
newrules.append(newrule)
|
||||
|
||||
return self.assert_sql_execution(db, callable_, *newrules)
|
||||
|
||||
def assert_sql_count(self, db, callable_, count):
|
||||
self.assert_sql_execution(
|
||||
db, callable_, assertsql.CountStatements(count)
|
||||
)
|
||||
|
||||
def assert_multiple_sql_count(self, dbs, callable_, counts):
|
||||
recs = [
|
||||
(self.sql_execution_asserter(db), db, count)
|
||||
for (db, count) in zip(dbs, counts)
|
||||
]
|
||||
asserters = []
|
||||
for ctx, db, count in recs:
|
||||
asserters.append(ctx.__enter__())
|
||||
try:
|
||||
return callable_()
|
||||
finally:
|
||||
for asserter, (ctx, db, count) in zip(asserters, recs):
|
||||
ctx.__exit__(None, None, None)
|
||||
asserter.assert_(assertsql.CountStatements(count))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def assert_execution(self, db, *rules):
|
||||
with self.sql_execution_asserter(db) as asserter:
|
||||
yield
|
||||
asserter.assert_(*rules)
|
||||
|
||||
def assert_statement_count(self, db, count):
|
||||
return self.assert_execution(db, assertsql.CountStatements(count))
|
||||
457
lib/sqlalchemy/testing/assertsql.py
Normal file
457
lib/sqlalchemy/testing/assertsql.py
Normal file
@@ -0,0 +1,457 @@
|
||||
# testing/assertsql.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import re
|
||||
|
||||
from .. import event
|
||||
from .. import util
|
||||
from ..engine import url
|
||||
from ..engine.default import DefaultDialect
|
||||
from ..engine.util import _distill_cursor_params
|
||||
from ..schema import _DDLCompiles
|
||||
|
||||
|
||||
class AssertRule(object):
    """Base class for rules applied to observed SQL executions by the
    assertion engine."""

    # set True once this rule has matched and should receive no more
    # statements
    is_consumed = False
    # set by process_statement() when a statement arrived but did not
    # match; reported as a test failure
    errormessage = None
    # whether a processed statement is removed from the pending list
    consume_statement = True

    def process_statement(self, execute_observed):
        """Inspect one observed execution; subclasses update
        is_consumed / errormessage.  Default is a no-op."""
        pass

    def no_more_statements(self):
        """Called when execution finished while this rule is still
        pending; always fails."""
        assert False, (
            "All statements are complete, but pending "
            "assertion rules remain"
        )
|
||||
|
||||
|
||||
class SQLMatchRule(AssertRule):
    """Base for rules that match observed statements against SQL text."""

    pass
|
||||
|
||||
|
||||
class CursorSQL(SQLMatchRule):
    """Rule matching the exact statement string (and optionally exact
    parameters) as sent to the cursor, with no recompilation against a
    comparison dialect."""

    def __init__(self, statement, params=None, consume_statement=True):
        # exact SQL string expected at the cursor
        self.statement = statement
        # expected cursor parameters; None disables the parameter check
        self.params = params
        self.consume_statement = consume_statement

    def process_statement(self, execute_observed):
        stmt = execute_observed.statements[0]
        if self.statement != stmt.statement or (
            self.params is not None and self.params != stmt.parameters
        ):
            # mismatch: record a failure message for the assertion engine
            self.errormessage = (
                "Testing for exact SQL %s parameters %s received %s %s"
                % (
                    self.statement,
                    self.params,
                    stmt.statement,
                    stmt.parameters,
                )
            )
        else:
            # matched: remove the observed statement and mark consumed
            execute_observed.statements.pop(0)
            self.is_consumed = True
            if not execute_observed.statements:
                self.consume_statement = True
|
||||
|
||||
|
||||
class CompiledSQL(SQLMatchRule):
    """Match one execution against an expected compiled SQL string.

    The observed clause element is re-compiled against a comparison
    dialect (``DefaultDialect`` unless overridden) and compared to
    ``statement``.  ``params`` may be a dict, a list of dicts, or a
    callable receiving the execution context and returning either.
    """

    def __init__(self, statement, params=None, dialect="default"):
        self.statement = statement
        self.params = params
        self.dialect = dialect

    def _compare_sql(self, execute_observed, received_statement):
        # newlines/tabs in the expected string are cosmetic only
        stmt = re.sub(r"[\n\t]", "", self.statement)
        return received_statement == stmt

    def _compile_dialect(self, execute_observed):
        # produce the dialect used to re-compile the observed statement
        if self.dialect == "default":
            dialect = DefaultDialect()
            # this is currently what tests are expecting
            # dialect.supports_default_values = True
            dialect.supports_default_metavalue = True
            return dialect
        else:
            # ugh
            if self.dialect == "postgresql":
                params = {"implicit_returning": True}
            else:
                params = {}
            return url.URL.create(self.dialect).get_dialect()(**params)

    def _received_statement(self, execute_observed):
        """reconstruct the statement and params in terms
        of a target dialect, which for CompiledSQL is just DefaultDialect."""

        context = execute_observed.context
        compare_dialect = self._compile_dialect(execute_observed)

        # received_statement runs a full compile().  we should not need to
        # consider extracted_parameters; if we do this indicates some state
        # is being sent from a previous cached query, which some misbehaviors
        # in the ORM can cause, see #6881
        cache_key = None  # execute_observed.context.compiled.cache_key
        extracted_parameters = (
            None  # execute_observed.context.extracted_parameters
        )

        if "schema_translate_map" in context.execution_options:
            map_ = context.execution_options["schema_translate_map"]
        else:
            map_ = None

        if isinstance(execute_observed.clauseelement, _DDLCompiles):
            # DDL compiles take no cache/parameter-related options
            compiled = execute_observed.clauseelement.compile(
                dialect=compare_dialect,
                schema_translate_map=map_,
            )
        else:
            compiled = execute_observed.clauseelement.compile(
                cache_key=cache_key,
                dialect=compare_dialect,
                column_keys=context.compiled.column_keys,
                for_executemany=context.compiled.for_executemany,
                schema_translate_map=map_,
            )
        _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
        parameters = execute_observed.parameters

        if not parameters:
            # no explicit parameters: take the compiled defaults
            _received_parameters = [
                compiled.construct_params(
                    extracted_parameters=extracted_parameters
                )
            ]
        else:
            _received_parameters = [
                compiled.construct_params(
                    m, extracted_parameters=extracted_parameters
                )
                for m in parameters
            ]

        return _received_statement, _received_parameters

    def process_statement(self, execute_observed):
        # compare SQL text first; if it matches, match each expected
        # parameter set against some received set (subset compare)
        context = execute_observed.context

        _received_statement, _received_parameters = self._received_statement(
            execute_observed
        )
        params = self._all_params(context)

        equivalent = self._compare_sql(execute_observed, _received_statement)

        if equivalent:
            if params is not None:
                all_params = list(params)
                all_received = list(_received_parameters)
                while all_params and all_received:
                    param = dict(all_params.pop(0))

                    for idx, received in enumerate(list(all_received)):
                        # do a positive compare only
                        for param_key in param:
                            # a key in param did not match current
                            # 'received'
                            if (
                                param_key not in received
                                or received[param_key] != param[param_key]
                            ):
                                break
                        else:
                            # all keys in param matched 'received';
                            # onto next param
                            del all_received[idx]
                            break
                    else:
                        # param did not match any entry
                        # in all_received
                        equivalent = False
                        break
                # leftovers on either side mean the sets didn't pair up
                if all_params or all_received:
                    equivalent = False

        if equivalent:
            self.is_consumed = True
            self.errormessage = None
        else:
            self.errormessage = self._failure_message(params) % {
                "received_statement": _received_statement,
                "received_parameters": _received_parameters,
            }

    def _all_params(self, context):
        # normalize self.params into a list of dicts, or None if unset
        if self.params:
            if callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            if not isinstance(params, list):
                params = [params]
            return params
        else:
            return None

    def _failure_message(self, expected_params):
        # %-escape the expected pieces so the returned template can still
        # be interpolated with received_statement / received_parameters
        return (
            "Testing for compiled statement\n%r partial params %s, "
            "received\n%%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self.statement.replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )
|
||||
|
||||
|
||||
class RegexSQL(CompiledSQL):
    """Match the compiled statement against a regular expression instead of
    an exact string; parameter comparison works as in CompiledSQL."""

    def __init__(self, regex, params=None, dialect="default"):
        SQLMatchRule.__init__(self)
        self.orig_regex = regex
        self.regex = re.compile(regex)
        self.params = params
        self.dialect = dialect

    def _failure_message(self, expected_params):
        template = (
            "Testing for compiled statement ~%r partial params %s, "
            "received %%(received_statement)r with params "
            "%%(received_parameters)r"
        )
        return template % (
            self.orig_regex.replace("%", "%%"),
            repr(expected_params).replace("%", "%%"),
        )

    def _compare_sql(self, execute_observed, received_statement):
        # anchored at the start only, as re.match does
        return self.regex.match(received_statement) is not None
|
||||
|
||||
|
||||
class DialectSQL(CompiledSQL):
    """CompiledSQL variant that compares against the dialect actually in
    use, translating bound-parameter markers in the expected string to
    that dialect's paramstyle."""

    def _compile_dialect(self, execute_observed):
        # use the real dialect the statement ran on, not DefaultDialect
        return execute_observed.context.dialect

    def _compare_no_space(self, real_stmt, received_stmt):
        # whitespace-insensitive equality on the actual emitted statement
        stmt = re.sub(r"[\n\t]", "", real_stmt)
        return received_stmt == stmt

    def _received_statement(self, execute_observed):
        received_stmt, received_params = super(
            DialectSQL, self
        )._received_statement(execute_observed)

        # TODO: why do we need this part?
        for real_stmt in execute_observed.statements:
            if self._compare_no_space(real_stmt.statement, received_stmt):
                break
        else:
            raise AssertionError(
                "Can't locate compiled statement %r in list of "
                "statements actually invoked" % received_stmt
            )

        return received_stmt, execute_observed.context.compiled_parameters

    def _compare_sql(self, execute_observed, received_statement):
        stmt = re.sub(r"[\n\t]", "", self.statement)
        # convert our comparison statement to have the
        # paramstyle of the received
        paramstyle = execute_observed.context.dialect.paramstyle
        if paramstyle == "pyformat":
            stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
        else:
            # positional params
            repl = None
            if paramstyle == "qmark":
                repl = "?"
            elif paramstyle == "format":
                repl = r"%s"
            elif paramstyle == "numeric":
                # NOTE(review): repl remains None for "numeric" (and any
                # unlisted paramstyle), which makes re.sub below raise
                # TypeError — presumably these styles are unsupported
                # here; confirm before relying on this path
                repl = None
            stmt = re.sub(r":([\w_]+)", repl, stmt)

        return received_statement == stmt
|
||||
|
||||
|
||||
class CountStatements(AssertRule):
    """Assert that exactly ``count`` statements were executed overall."""

    def __init__(self, count):
        self.count = count
        self._statement_count = 0

    def process_statement(self, execute_observed):
        # every observed execution bumps the tally; nothing is matched
        self._statement_count += 1

    def no_more_statements(self):
        actual = self._statement_count
        if actual != self.count:
            assert False, "desired statement count %d does not match %d" % (
                self.count,
                actual,
            )
|
||||
|
||||
|
||||
class AllOf(AssertRule):
    """Assert that every one of the given rules is satisfied, in any
    order; each statement is offered to the remaining rules in turn."""

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_statement(self, execute_observed):
        for rule in list(self.rules):
            # clear any stale failure from a previous statement
            rule.errormessage = None
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.rules.discard(rule)
                if not self.rules:
                    self.is_consumed = True
                break
            elif not rule.errormessage:
                # rule is not done yet
                self.errormessage = None
                break
        else:
            # every remaining rule rejected this statement; surface one
            # of their error messages
            self.errormessage = list(self.rules)[0].errormessage
|
||||
|
||||
|
||||
class EachOf(AssertRule):
    """Assert that the given rules are satisfied strictly in order."""

    def __init__(self, *rules):
        self.rules = list(rules)

    def process_statement(self, execute_observed):
        while self.rules:
            rule = self.rules[0]
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.rules.pop(0)
            elif rule.errormessage:
                self.errormessage = rule.errormessage
            # a rule that doesn't consume the statement gets to see the
            # next statement too; otherwise stop here
            if rule.consume_statement:
                break

        if not self.rules:
            self.is_consumed = True

    def no_more_statements(self):
        if self.rules and not self.rules[0].is_consumed:
            # let the pending head rule report its own failure
            self.rules[0].no_more_statements()
        elif self.rules:
            super(EachOf, self).no_more_statements()
|
||||
|
||||
|
||||
class Conditional(EachOf):
    """EachOf variant that selects between two rule sequences up front
    based on a boolean condition."""

    def __init__(self, condition, rules, else_rules):
        chosen = rules if condition else else_rules
        super(Conditional, self).__init__(*chosen)
|
||||
|
||||
|
||||
class Or(AllOf):
    """Assert that at least one of the given rules matches."""

    def process_statement(self, execute_observed):
        for rule in self.rules:
            rule.process_statement(execute_observed)
            if rule.is_consumed:
                self.is_consumed = True
                break
        else:
            # no rule matched; surface one of their error messages
            self.errormessage = list(self.rules)[0].errormessage
|
||||
|
||||
|
||||
class SQLExecuteObserved(object):
    """Record of one Core-level execution plus the cursor-level statements
    it produced (accumulated in ``statements``)."""

    def __init__(self, context, clauseelement, multiparams, params):
        self.context = context
        self.clauseelement = clauseelement
        # normalize the multiparams/params calling forms into a list
        self.parameters = _distill_cursor_params(
            context.connection, tuple(multiparams), params
        )
        self.statements = []

    def __repr__(self):
        return str(self.statements)
|
||||
|
||||
|
||||
class SQLCursorExecuteObserved(
    collections.namedtuple(
        "SQLCursorExecuteObserved",
        ["statement", "parameters", "context", "executemany"],
    )
):
    """Record of a single cursor-level execute()/executemany() call."""

    pass
|
||||
|
||||
|
||||
class SQLAsserter(object):
    """Accumulates observed executions during a test, then checks them
    against assertion rules via :meth:`.assert_`."""

    def __init__(self):
        self.accumulated = []

    def _close(self):
        # freeze the recording; further accumulation is no longer possible
        self._final = self.accumulated
        del self.accumulated

    def assert_(self, *rules):
        rule = EachOf(*rules)

        remaining = list(self._final)
        while remaining:
            rule.process_statement(remaining.pop(0))
            if rule.is_consumed:
                break
            if rule.errormessage:
                assert False, rule.errormessage
        if remaining:
            assert False, "Additional SQL statements remain:\n%s" % remaining
        elif not rule.is_consumed:
            rule.no_more_statements()
|
||||
|
||||
|
||||
@contextlib.contextmanager
def assert_engine(engine):
    """Context manager yielding a :class:`.SQLAsserter` that records all
    SQL emitted by ``engine`` inside the block.

    On exit the event hooks are removed and the asserter is closed, after
    which ``asserter.assert_()`` may be called with assertion rules.
    """
    asserter = SQLAsserter()

    orig = []

    @event.listens_for(engine, "before_execute")
    def connection_execute(
        conn, clauseelement, multiparams, params, execution_options
    ):
        # grab the original statement + params before any cursor
        # execution
        orig[:] = clauseelement, multiparams, params

    @event.listens_for(engine, "after_cursor_execute")
    def cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        if not context:
            return
        # then grab real cursor statements and associate them all
        # around a single context
        if (
            asserter.accumulated
            and asserter.accumulated[-1].context is context
        ):
            obs = asserter.accumulated[-1]
        else:
            obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
            asserter.accumulated.append(obs)
        obs.statements.append(
            SQLCursorExecuteObserved(
                statement, parameters, context, executemany
            )
        )

    try:
        yield asserter
    finally:
        # remove hooks in reverse registration order, then freeze results
        event.remove(engine, "after_cursor_execute", cursor_execute)
        event.remove(engine, "before_execute", connection_execute)
        asserter._close()
|
||||
128
lib/sqlalchemy/testing/asyncio.py
Normal file
128
lib/sqlalchemy/testing/asyncio.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# testing/asyncio.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
|
||||
# functions and wrappers to run tests, fixtures, provisioning and
|
||||
# setup/teardown in an asyncio event loop, conditionally based on the
|
||||
# current DB driver being used for a test.
|
||||
|
||||
# note that SQLAlchemy's asyncio integration also supports a method
|
||||
# of running individual asyncio functions inside of separate event loops
|
||||
# using "async_fallback" mode; however running whole functions in the event
|
||||
# loop is a more accurate test for how SQLAlchemy's asyncio features
|
||||
# would run in the real world.
|
||||
|
||||
|
||||
from functools import wraps
|
||||
import inspect
|
||||
|
||||
from . import config
|
||||
from ..util.concurrency import _util_async_run
|
||||
from ..util.concurrency import _util_async_run_coroutine_function
|
||||
|
||||
# Global switch for the asyncio test harness; may be set to False if the
# --disable-asyncio flag is passed to the test runner.
ENABLE_ASYNCIO = True
|
||||
|
||||
|
||||
def _run_coroutine_function(fn, *args, **kwargs):
    """Run coroutine function ``fn`` to completion and return its result."""
    return _util_async_run_coroutine_function(fn, *args, **kwargs)
|
||||
|
||||
|
||||
def _assume_async(fn, *args, **kwargs):
    """Run a function in an asyncio loop unconditionally.

    This function is used for provisioning features like
    testing a database connection for server info.

    Note that for blocking IO database drivers, this means they block the
    event loop.

    """
    if ENABLE_ASYNCIO:
        return _util_async_run(fn, *args, **kwargs)
    return fn(*args, **kwargs)
|
||||
|
||||
|
||||
def _maybe_async_provisioning(fn, *args, **kwargs):
    """Run a function in an asyncio loop if any current drivers might need it.

    This function is used for provisioning features that take
    place outside of a specific database driver being selected, so if the
    current driver that happens to be used for the provisioning operation
    is an async driver, it will run in asyncio and not fail.

    Note that for blocking IO database drivers, this means they block the
    event loop.

    """
    # only route through the event loop when asyncio is enabled AND some
    # configured backend is async
    if not ENABLE_ASYNCIO or not config.any_async:
        return fn(*args, **kwargs)
    return _util_async_run(fn, *args, **kwargs)
|
||||
|
||||
|
||||
def _maybe_async(fn, *args, **kwargs):
    """Run a function in an asyncio loop if the current selected driver is
    async.

    This function is used for test setup/teardown and tests themselves
    where the current DB driver is known.

    """
    # short-circuit keeps config._current untouched when asyncio is off
    if ENABLE_ASYNCIO and config._current.is_async:
        return _util_async_run(fn, *args, **kwargs)
    return fn(*args, **kwargs)
|
||||
|
||||
|
||||
def _maybe_async_wrapper(fn):
    """Apply the _maybe_async function to an existing function and return
    as a wrapped callable, supporting generator functions as well.

    This is currently used for pytest fixtures that support generator use.

    """

    if inspect.isgeneratorfunction(fn):
        # sentinel marking generator exhaustion, since StopIteration
        # cannot propagate out of the (possibly async) runner
        _stop = object()

        def call_next(gen):
            try:
                return next(gen)
            # can't raise StopIteration in an awaitable.
            except StopIteration:
                return _stop

        @wraps(fn)
        def wrap_fixture(*args, **kwargs):
            gen = fn(*args, **kwargs)
            while True:
                # each generator step may run inside the event loop
                value = _maybe_async(call_next, gen)
                if value is _stop:
                    break
                yield value

    else:

        @wraps(fn)
        def wrap_fixture(*args, **kwargs):
            return _maybe_async(fn, *args, **kwargs)

    return wrap_fixture
|
||||
209
lib/sqlalchemy/testing/config.py
Normal file
209
lib/sqlalchemy/testing/config.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# testing/config.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
import collections
|
||||
|
||||
from .. import util
|
||||
|
||||
# ---- global test-configuration state, installed by the test plugin ----

# active requirements module for the current run
requirements = None
# the current default Engine
db = None
# URL of the current default database
db_url = None
# engine options for the current default database
db_opts = None
# parsed ini-style configuration file
file_config = None
# names of the auxiliary schemas used by schema-level tests
test_schema = None
test_schema_2 = None
# True once any configured backend uses an async driver
any_async = False
# the Config instance currently in effect
_current = None
ident = "main"

_fixture_functions = None  # installed by plugin_base
|
||||
|
||||
|
||||
def combinations(*comb, **kw):
    r"""Deliver multiple versions of a test based on positional combinations.

    This is a facade over pytest.mark.parametrize.


    :param \*comb: argument combinations. These are tuples that will be passed
     positionally to the decorated function.

    :param argnames: optional list of argument names.  These are the names
     of the arguments in the test function that correspond to the entries
     in each argument tuple.   pytest.mark.parametrize requires this, however
     the combinations function will derive it automatically if not present
     by using ``inspect.getfullargspec(fn).args[1:]``.  Note this assumes the
     first argument is "self" which is discarded.

    :param id\_: optional id template.  This is a string template that
     describes how the "id" for each parameter set should be defined, if any.
     The number of characters in the template should match the number of
     entries in each argument tuple.   Each character describes how the
     corresponding entry in the argument tuple should be handled, as far as
     whether or not it is included in the arguments passed to the function, as
     well as if it is included in the tokens used to create the id of the
     parameter set.

     If omitted, the argument combinations are passed to parametrize as is.  If
     passed, each argument combination is turned into a pytest.param() object,
     mapping the elements of the argument tuple to produce an id based on a
     character value in the same position within the string template using the
     following scheme::

        i - the given argument is a string that is part of the id only, don't
            pass it as an argument

        n - the given argument should be passed and it should be added to the
            id by calling the .__name__ attribute

        r - the given argument should be passed and it should be added to the
            id by calling repr()

        s - the given argument should be passed and it should be added to the
            id by calling str()

        a - (argument) the given argument should be passed and it should not
            be used to generate the id

     e.g.::

        @testing.combinations(
            (operator.eq, "eq"),
            (operator.ne, "ne"),
            (operator.gt, "gt"),
            (operator.lt, "lt"),
            id_="na"
        )
        def test_operator(self, opfunc, name):
            pass

    The above combination will call ``.__name__`` on the first member of
    each tuple and use that as the "id" to pytest.param().


    """
    return _fixture_functions.combinations(*comb, **kw)
|
||||
|
||||
|
||||
def combinations_list(arg_iterable, **kw):
    """As :func:`.combinations`, but takes a single iterable of
    argument tuples instead of ``*args``."""
    return combinations(*arg_iterable, **kw)
|
||||
|
||||
|
||||
def fixture(*arg, **kw):
    """Facade over the test runner's fixture decorator, installed by
    plugin_base."""
    return _fixture_functions.fixture(*arg, **kw)
|
||||
|
||||
|
||||
def get_current_test_name():
    """Return the name of the currently running test, per the runner."""
    return _fixture_functions.get_current_test_name()
|
||||
|
||||
|
||||
def mark_base_test_class():
    """Return the runner's marker applied to base test classes."""
    return _fixture_functions.mark_base_test_class()
|
||||
|
||||
|
||||
class Config(object):
    """Per-database configuration: an engine plus its options, with
    classmethods managing the set of registered configs and a stack of
    "current" configs whose state is mirrored into module globals."""

    def __init__(self, db, db_opts, options, file_config):
        self._set_name(db)
        self.db = db
        self.db_opts = db_opts
        self.options = options
        self.file_config = file_config
        self.test_schema = "test_schema"
        self.test_schema_2 = "test_schema_2"

        # async driver without the "async_fallback" URL escape hatch
        self.is_async = db.dialect.is_async and not util.asbool(
            db.url.query.get("async_fallback", False)
        )

    # class-level registry state, shared by all Config instances
    _stack = collections.deque()
    _configs = set()

    def _set_name(self, db):
        # display name: "dialect+driver[_version]" when the server
        # version is known
        if db.dialect.server_version_info:
            svi = ".".join(str(tok) for tok in db.dialect.server_version_info)
            self.name = "%s+%s_[%s]" % (db.name, db.driver, svi)
        else:
            self.name = "%s+%s" % (db.name, db.driver)

    @classmethod
    def register(cls, db, db_opts, options, file_config):
        """add a config as one of the global configs.

        If there are no configs set up yet, this config also
        gets set as the "_current".
        """
        global any_async

        cfg = Config(db, db_opts, options, file_config)

        # if any backends include an async driver, then ensure
        # all setup/teardown and tests are wrapped in the maybe_async()
        # decorator that will set up a greenlet context for async drivers.
        any_async = any_async or cfg.is_async

        cls._configs.add(cfg)
        return cfg

    @classmethod
    def set_as_current(cls, config, namespace):
        # mirror the chosen config into the module-level globals and into
        # the caller-provided namespace (usually the testing package)
        global db, _current, db_url, test_schema, test_schema_2, db_opts
        _current = config
        db_url = config.db.url
        db_opts = config.db_opts
        test_schema = config.test_schema
        test_schema_2 = config.test_schema_2
        namespace.db = db = config.db

    @classmethod
    def push_engine(cls, db, namespace):
        # build a Config around an ad-hoc engine using the current
        # config's settings, then push it
        assert _current, "Can't push without a default Config set up"
        cls.push(
            Config(
                db, _current.db_opts, _current.options, _current.file_config
            ),
            namespace,
        )

    @classmethod
    def push(cls, config, namespace):
        cls._stack.append(_current)
        cls.set_as_current(config, namespace)

    @classmethod
    def pop(cls, namespace):
        if cls._stack:
            # a failed test w/ -x option can call reset() ahead of time
            _current = cls._stack[-1]
            del cls._stack[-1]
            cls.set_as_current(_current, namespace)

    @classmethod
    def reset(cls, namespace):
        # restore the bottom-of-stack config and clear the stack
        if cls._stack:
            cls.set_as_current(cls._stack[0], namespace)
            cls._stack.clear()

    @classmethod
    def all_configs(cls):
        return cls._configs

    @classmethod
    def all_dbs(cls):
        for cfg in cls.all_configs():
            yield cfg.db

    def skip_test(self, msg):
        # delegates to the module-level skip_test()
        skip_test(msg)
|
||||
|
||||
|
||||
def skip_test(msg):
    """Skip the current test by raising the runner's skip exception."""
    raise _fixture_functions.skip_test_exception(msg)
|
||||
|
||||
|
||||
def async_test(fn):
    """Decorator marking ``fn`` to be run inside an asyncio event loop."""
    return _fixture_functions.async_test(fn)
|
||||
465
lib/sqlalchemy/testing/engines.py
Normal file
465
lib/sqlalchemy/testing/engines.py
Normal file
@@ -0,0 +1,465 @@
|
||||
# testing/engines.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import collections
|
||||
import re
|
||||
import warnings
|
||||
import weakref
|
||||
|
||||
from . import config
|
||||
from .util import decorator
|
||||
from .util import gc_collect
|
||||
from .. import event
|
||||
from .. import pool
|
||||
from ..util import await_only
|
||||
|
||||
|
||||
class ConnectionKiller(object):
    """Tracks pools, engines, and DBAPI connections created during tests so
    they can be rolled back, checked in, and disposed between tests."""

    def __init__(self):
        # ConnectionFairy proxies still alive, held weakly
        self.proxy_refs = weakref.WeakKeyDictionary()
        # engines registered per teardown scope
        self.testing_engines = collections.defaultdict(set)
        # raw DBAPI connections currently checked out
        self.dbapi_connections = set()

    def add_pool(self, pool):
        # track checkouts/checkins on this pool
        event.listen(pool, "checkout", self._add_conn)
        event.listen(pool, "checkin", self._remove_conn)
        event.listen(pool, "close", self._remove_conn)
        event.listen(pool, "close_detached", self._remove_conn)
        # note we are keeping "invalidated" here, as those are still
        # opened connections we would like to roll back

    def _add_conn(self, dbapi_con, con_record, con_proxy):
        self.dbapi_connections.add(dbapi_con)
        self.proxy_refs[con_proxy] = True

    def _remove_conn(self, dbapi_conn, *arg):
        self.dbapi_connections.discard(dbapi_conn)

    def add_engine(self, engine, scope):
        self.add_pool(engine.pool)

        assert scope in ("class", "global", "function", "fixture")
        self.testing_engines[scope].add(engine)

    def _safe(self, fn):
        # best-effort call; a failed rollback/close becomes a warning
        try:
            fn()
        except Exception as e:
            warnings.warn(
                "testing_reaper couldn't rollback/close connection: %s" % e
            )

    def rollback_all(self):
        for rec in list(self.proxy_refs):
            if rec is not None and rec.is_valid:
                self._safe(rec.rollback)

    def checkin_all(self):
        # run pool.checkin() for all ConnectionFairy instances we have
        # tracked.

        for rec in list(self.proxy_refs):
            if rec is not None and rec.is_valid:
                self.dbapi_connections.discard(rec.dbapi_connection)
                self._safe(rec._checkin)

        # for fairy refs that were GCed and could not close the connection,
        # such as asyncio, roll back those remaining connections
        for con in self.dbapi_connections:
            self._safe(con.rollback)
        self.dbapi_connections.clear()

    def close_all(self):
        self.checkin_all()

    def prepare_for_drop_tables(self, connection):
        # don't do aggressive checks for third party test suites
        if not config.bootstrapped_as_sqlalchemy:
            return

        from . import provision

        provision.prepare_for_drop_tables(connection.engine.url, connection)

    def _drop_testing_engines(self, scope):
        eng = self.testing_engines[scope]
        for rec in list(eng):
            # check in any fairy still attached to this engine's pool
            for proxy_ref in list(self.proxy_refs):
                if proxy_ref is not None and proxy_ref.is_valid:
                    if (
                        proxy_ref._pool is not None
                        and proxy_ref._pool is rec.pool
                    ):
                        self._safe(proxy_ref._checkin)
            # async engines expose sync_engine; their dispose() is a
            # coroutine and must run via await_only
            if hasattr(rec, "sync_engine"):
                await_only(rec.dispose())
            else:
                rec.dispose()
        eng.clear()

    def after_test(self):
        self._drop_testing_engines("function")

    def after_test_outside_fixtures(self, test):
        # don't do aggressive checks for third party test suites
        if not config.bootstrapped_as_sqlalchemy:
            return

        if test.__class__.__leave_connections_for_teardown__:
            return

        self.checkin_all()

        # on PostgreSQL, this will test for any "idle in transaction"
        # connections.   useful to identify tests with unusual patterns
        # that can't be cleaned up correctly.
        from . import provision

        with config.db.connect() as conn:
            provision.prepare_for_drop_tables(conn.engine.url, conn)

    def stop_test_class_inside_fixtures(self):
        self.checkin_all()
        self._drop_testing_engines("function")
        self._drop_testing_engines("class")

    def stop_test_class_outside_fixtures(self):
        # ensure no refs to checked out connections at all.

        if pool.base._strong_ref_connection_records:
            gc_collect()

            if pool.base._strong_ref_connection_records:
                ln = len(pool.base._strong_ref_connection_records)
                pool.base._strong_ref_connection_records.clear()
                assert (
                    False
                ), "%d connection recs not cleared after test suite" % (ln)

    def final_cleanup(self):
        self.checkin_all()
        for scope in self.testing_engines:
            self._drop_testing_engines(scope)

    def assert_all_closed(self):
        for rec in self.proxy_refs:
            if rec.is_valid:
                assert False
|
||||
|
||||
|
||||
testing_reaper = ConnectionKiller()
|
||||
|
||||
|
||||
@decorator
def assert_conns_closed(fn, *args, **kw):
    """Decorator asserting that no connections remain checked out after
    ``fn`` runs."""
    try:
        fn(*args, **kw)
    finally:
        testing_reaper.assert_all_closed()
|
||||
|
||||
|
||||
@decorator
def rollback_open_connections(fn, *args, **kw):
    """Decorator that rolls back all open connections after fn execution."""

    try:
        fn(*args, **kw)
    finally:
        testing_reaper.rollback_all()
|
||||
|
||||
|
||||
@decorator
def close_first(fn, *args, **kw):
    """Decorator that closes all connections before fn execution."""

    # NOTE(review): unlike close_open_connections there is no try/finally
    # here, and fn's return value is discarded — presumably intentional for
    # test-function decorators; confirm before reusing elsewhere
    testing_reaper.checkin_all()
    fn(*args, **kw)
|
||||
|
||||
|
||||
@decorator
def close_open_connections(fn, *args, **kw):
    """Decorator that closes all connections after fn execution."""
    try:
        fn(*args, **kw)
    finally:
        testing_reaper.checkin_all()
|
||||
|
||||
|
||||
def all_dialects(exclude=None):
    """Yield an instance of every built-in dialect.

    :param exclude: optional collection of dialect names to skip.
    """
    import sqlalchemy.dialects as d

    for name in d.__all__:
        # TEMPORARY
        if exclude and name in exclude:
            continue
        mod = getattr(d, name, None)
        if not mod:
            # not yet imported; import the submodule explicitly
            mod = getattr(
                __import__("sqlalchemy.dialects.%s" % name).dialects, name
            )
        yield mod.dialect()
|
||||
|
||||
|
||||
class ReconnectFixture(object):
    """DBAPI wrapper that simulates database restarts by force-closing the
    connections it has handed out."""

    def __init__(self, dbapi):
        self.dbapi = dbapi
        self.connections = []
        # when True, new connects simulate a dead database
        self.is_stopped = False

    def __getattr__(self, key):
        # delegate everything else to the wrapped DBAPI module
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):

        conn = self.dbapi.connect(*args, **kwargs)
        if self.is_stopped:
            self._safe(conn.close)
            curs = conn.cursor()  # should fail on Oracle etc.
            # should fail for everything that didn't fail
            # above, connection is closed
            curs.execute("select 1")
            assert False, "simulated connect failure didn't work"
        else:
            self.connections.append(conn)
            return conn

    def _safe(self, fn):
        # best-effort close; failures become warnings
        try:
            fn()
        except Exception as e:
            warnings.warn("ReconnectFixture couldn't close connection: %s" % e)

    def shutdown(self, stop=False):
        # TODO: this doesn't cover all cases
        # as nicely as we'd like, namely MySQLdb.
        # would need to implement R. Brewer's
        # proxy server idea to get better
        # coverage.
        self.is_stopped = stop
        for c in list(self.connections):
            self._safe(c.close)
        self.connections = []

    def restart(self):
        self.is_stopped = False
|
||||
|
||||
|
||||
def reconnecting_engine(url=None, options=None):
    """Return a testing engine whose DBAPI is wrapped in
    :class:`.ReconnectFixture`, exposing ``test_shutdown()`` /
    ``test_restart()`` to simulate database restarts.
    """
    url = url or config.db.url
    dbapi = config.db.dialect.dbapi
    if not options:
        options = {}
    options["module"] = ReconnectFixture(dbapi)
    engine = testing_engine(url, options)
    _dispose = engine.dispose

    def dispose():
        # close fixture-held connections before the real disposal
        engine.dialect.dbapi.shutdown()
        engine.dialect.dbapi.is_stopped = False
        _dispose()

    engine.test_shutdown = engine.dialect.dbapi.shutdown
    engine.test_restart = engine.dialect.dbapi.restart
    engine.dispose = dispose
    return engine
|
||||
|
||||
|
||||
def testing_engine(
    url=None,
    options=None,
    future=None,
    asyncio=False,
    transfer_staticpool=False,
    _sqlite_savepoint=False,
):
    """Produce an engine configured by --options with optional overrides.

    :param url: database URL; defaults to the configured testing URL.
    :param options: dict of create_engine() keyword arguments; the keys
        ``use_reaper``, ``scope`` and ``sqlite_savepoint`` are popped out
        and consumed here rather than passed through.
    :param future: select the 2.0-style ("future") create_engine.
    :param asyncio: produce an asyncio engine instead.
    :param transfer_staticpool: reuse the main engine's StaticPool.
    """

    if asyncio:
        # NOTE(review): _sqlite_savepoint is only consulted here; the
        # savepoint workaround itself is driven by options["sqlite_savepoint"]
        assert not _sqlite_savepoint
        from sqlalchemy.ext.asyncio import (
            create_async_engine as create_engine,
        )
    elif future or (
        config.db and config.db._is_future and future is not False
    ):
        from sqlalchemy.future import create_engine
    else:
        from sqlalchemy import create_engine
    from sqlalchemy.engine.url import make_url

    # pop testing-only flags out of options before they reach create_engine()
    if not options:
        use_reaper = True
        scope = "function"
        sqlite_savepoint = False
    else:
        use_reaper = options.pop("use_reaper", True)
        scope = options.pop("scope", "function")
        sqlite_savepoint = options.pop("sqlite_savepoint", False)

    url = url or config.db.url

    url = make_url(url)
    if options is None:
        # no overrides: inherit the configured options when targeting the
        # same driver, otherwise start clean
        if config.db is None or url.drivername == config.db.url.drivername:
            options = config.db_opts
        else:
            options = {}
    elif config.db is not None and url.drivername == config.db.url.drivername:
        # NOTE(review): default_opt is computed but never used —
        # create_engine() below receives only ``options``; confirm intent
        default_opt = config.db_opts.copy()
        default_opt.update(options)

    engine = create_engine(url, **options)

    if sqlite_savepoint and engine.name == "sqlite":
        # apply SQLite savepoint workaround
        @event.listens_for(engine, "connect")
        def do_connect(dbapi_connection, connection_record):
            dbapi_connection.isolation_level = None

        @event.listens_for(engine, "begin")
        def do_begin(conn):
            conn.exec_driver_sql("BEGIN")

    if transfer_staticpool:
        from sqlalchemy.pool import StaticPool

        if config.db is not None and isinstance(config.db.pool, StaticPool):
            # pool is shared with the main engine; the reaper must not
            # close its connections
            use_reaper = False
            engine.pool._transfer_from(config.db.pool)

    if scope == "global":
        if asyncio:
            engine.sync_engine._has_events = True
        else:
            engine._has_events = (
                True  # enable event blocks, helps with profiling
            )

    if isinstance(engine.pool, pool.QueuePool):
        # fail fast in tests rather than waiting on a full pool
        engine.pool._timeout = 0
        engine.pool._max_overflow = 0
    if use_reaper:
        testing_reaper.add_engine(engine, scope)

    return engine
|
||||
|
||||
|
||||
def mock_engine(dialect_name=None):
    """Provides a mocking engine based on the current testing.db.

    This is normally used to test DDL generation flow as emitted
    by an Engine.

    It should not be used in other cases, as assert_compile() and
    assert_sql_execution() are much better choices with fewer
    moving parts.

    """

    from sqlalchemy import create_mock_engine

    dialect_name = dialect_name or config.db.name

    buffer = []

    def executor(sql, *a, **kw):
        # capture each emitted statement
        buffer.append(sql)

    def assert_sql(stmts):
        received = [re.sub(r"[\n\t]", "", str(stmt)) for stmt in buffer]
        assert received == stmts, received

    def print_sql():
        dialect = engine.dialect
        return "\n".join(
            str(stmt.compile(dialect=dialect)) for stmt in engine.mock
        )

    engine = create_mock_engine(dialect_name + "://", executor)
    assert not hasattr(engine, "mock")
    engine.mock = buffer
    engine.assert_sql = assert_sql
    engine.print_sql = print_sql
    return engine
|
||||
|
||||
|
||||
class DBAPIProxyCursor(object):
    """Proxy a DBAPI cursor.

    Tests can provide subclasses of this to intercept
    DBAPI-level cursor operations.

    """

    def __init__(self, engine, conn, *args, **kwargs):
        self.engine = engine
        self.connection = conn
        self.cursor = conn.cursor(*args, **kwargs)

    def execute(self, stmt, parameters=None, **kw):
        # pass parameters positionally only when supplied; some DBAPIs
        # distinguish between the one- and two-argument call forms
        if not parameters:
            return self.cursor.execute(stmt, **kw)
        return self.cursor.execute(stmt, parameters, **kw)

    def executemany(self, stmt, params, **kw):
        return self.cursor.executemany(stmt, params, **kw)

    def __iter__(self):
        return iter(self.cursor)

    def __getattr__(self, key):
        # everything not intercepted goes to the real cursor
        return getattr(self.cursor, key)
|
||||
|
||||
|
||||
class DBAPIProxyConnection(object):
    """Proxy a DBAPI connection.

    Tests can provide subclasses of this to intercept
    DBAPI-level connection operations.

    """

    def __init__(self, engine, cursor_cls):
        self.conn = engine.pool._creator()
        self.engine = engine
        self.cursor_cls = cursor_cls

    def cursor(self, *args, **kwargs):
        # hand back the interception cursor wrapping the real connection
        return self.cursor_cls(self.engine, self.conn, *args, **kwargs)

    def close(self):
        self.conn.close()

    def __getattr__(self, key):
        # everything not intercepted goes to the real connection
        return getattr(self.conn, key)
|
||||
|
||||
|
||||
def proxying_engine(
    conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor
):
    """Produce an engine that provides proxy hooks for
    common methods.

    """

    def _make_proxy_connection():
        return conn_cls(config.db, cursor_cls)

    def _wrap_do_on_connect(do_on_connect):
        def wrapped(dbapi_conn):
            # unwrap the proxy before handing off to the real hook
            return do_on_connect(dbapi_conn.conn)

        return wrapped

    return testing_engine(
        options={
            "creator": _make_proxy_connection,
            "_wrap_do_on_connect": _wrap_do_on_connect,
        }
    )
|
||||
111
lib/sqlalchemy/testing/entities.py
Normal file
111
lib/sqlalchemy/testing/entities.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# testing/entities.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
import sqlalchemy as sa
|
||||
from .. import exc as sa_exc
|
||||
from ..util import compat
|
||||
|
||||
# ids of instances currently inside __repr__; guards against infinite
# recursion through reference cycles
_repr_stack = set()
|
||||
|
||||
|
||||
class BasicEntity(object):
    """Simple test entity: attributes are set from keyword arguments and
    rendered in a recursion-safe ``repr``."""

    def __init__(self, **kw):
        for attr, val in kw.items():
            setattr(self, attr, val)

    def __repr__(self):
        # fall back to the default repr when already inside this object's
        # repr (reference cycle)
        if id(self) in _repr_stack:
            return object.__repr__(self)
        _repr_stack.add(id(self))
        try:
            public = [
                "%s=%r" % (key, getattr(self, key))
                for key in sorted(self.__dict__.keys())
                if not key.startswith("_")
            ]
            return "%s(%s)" % (self.__class__.__name__, ", ".join(public))
        finally:
            _repr_stack.remove(id(self))
|
||||
|
||||
|
||||
# ids of instances currently inside __eq__; guards against infinite
# recursion when comparing object graphs with cycles
_recursion_stack = set()
|
||||
|
||||
|
||||
class ComparableMixin(object):
    # Provides a deep, sparse equality comparison for ORM test entities.

    def __ne__(self, other):
        return not self.__eq__(other)

    def __eq__(self, other):
        """Deep, sparse compare.

        Deeply compare two entities, following the non-None attributes of the
        non-persisted object, if possible.

        """
        if other is self:
            return True
        elif not self.__class__ == other.__class__:
            return False

        # cycles in the object graph compare as equal once revisited
        if id(self) in _recursion_stack:
            return True
        _recursion_stack.add(id(self))

        try:
            # pick the entity that's not SA persisted as the source
            try:
                self_key = sa.orm.attributes.instance_state(self).key
            except sa.orm.exc.NO_STATE:
                self_key = None

            if other is None:
                a = self
                b = other
            elif self_key is not None:
                a = other
                b = self
            else:
                a = self
                b = other

            # compare only the public attributes present on the source
            for attr in list(a.__dict__):
                if attr.startswith("_"):
                    continue
                value = getattr(a, attr)

                try:
                    # handle lazy loader errors
                    battr = getattr(b, attr)
                except (AttributeError, sa_exc.UnboundExecutionError):
                    return False

                if hasattr(value, "__iter__") and not isinstance(
                    value, compat.string_types
                ):
                    # sequences compare ordered; sets/mappings compare
                    # unordered via their key sets
                    if hasattr(value, "__getitem__") and not hasattr(
                        value, "keys"
                    ):
                        if list(value) != list(battr):
                            return False
                    else:
                        if set(value) != set(battr):
                            return False
                else:
                    # "sparse": a None value on the source matches anything
                    if value is not None and value != battr:
                        return False
            return True
        finally:
            _recursion_stack.remove(id(self))
|
||||
|
||||
|
||||
class ComparableEntity(ComparableMixin, BasicEntity):
    # hash by class so instances remain hashable while equality is the
    # deep compare supplied by ComparableMixin
    def __hash__(self):
        return hash(self.__class__)
|
||||
465
lib/sqlalchemy/testing/exclusions.py
Normal file
465
lib/sqlalchemy/testing/exclusions.py
Normal file
@@ -0,0 +1,465 @@
|
||||
# testing/exclusions.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
|
||||
import contextlib
|
||||
import operator
|
||||
import re
|
||||
import sys
|
||||
|
||||
from . import config
|
||||
from .. import util
|
||||
from ..util import decorator
|
||||
from ..util.compat import inspect_getfullargspec
|
||||
|
||||
|
||||
def skip_if(predicate, reason=None):
    """Return a rule that skips the test when *predicate* holds."""
    rule = compound()
    rule.skips.add(_as_predicate(predicate, reason))
    return rule
|
||||
|
||||
|
||||
def fails_if(predicate, reason=None):
    """Return a rule marking the test as expected-to-fail when *predicate*
    holds."""
    rule = compound()
    rule.fails.add(_as_predicate(predicate, reason))
    return rule
|
||||
|
||||
|
||||
class compound(object):
    # Aggregates skip / expected-failure / tag predicates for a test, and
    # applies them either as a decorator or a context manager.

    def __init__(self):
        self.fails = set()
        self.skips = set()
        self.tags = set()

    def __add__(self, other):
        return self.add(other)

    def as_skips(self):
        # copy where every "expected failure" predicate becomes a hard skip
        rule = compound()
        rule.skips.update(self.skips)
        rule.skips.update(self.fails)
        rule.tags.update(self.tags)
        return rule

    def add(self, *others):
        # non-mutating union of this rule with the given rules
        copy = compound()
        copy.fails.update(self.fails)
        copy.skips.update(self.skips)
        copy.tags.update(self.tags)
        for other in others:
            copy.fails.update(other.fails)
            copy.skips.update(other.skips)
            copy.tags.update(other.tags)
        return copy

    def not_(self):
        # copy with every predicate logically inverted (tags unchanged)
        copy = compound()
        copy.fails.update(NotPredicate(fail) for fail in self.fails)
        copy.skips.update(NotPredicate(skip) for skip in self.skips)
        copy.tags.update(self.tags)
        return copy

    @property
    def enabled(self):
        return self.enabled_for_config(config._current)

    def enabled_for_config(self, config):
        # enabled only if no skip/fail predicate matches this config
        for predicate in self.skips.union(self.fails):
            if predicate(config):
                return False
        else:
            return True

    def matching_config_reasons(self, config):
        # human-readable reasons for every predicate matching this config
        return [
            predicate._as_string(config)
            for predicate in self.skips.union(self.fails)
            if predicate(config)
        ]

    def include_test(self, include_tags, exclude_tags):
        # excluded tags always win; otherwise require an include-tag match
        # when include_tags were given at all
        return bool(
            not self.tags.intersection(exclude_tags)
            and (not include_tags or self.tags.intersection(include_tags))
        )

    def _extend(self, other):
        # in-place merge, used when stacking multiple exclusion decorators
        self.skips.update(other.skips)
        self.fails.update(other.fails)
        self.tags.update(other.tags)

    def __call__(self, fn):
        # decorator form: if fn is already wrapped, just merge into the
        # existing rule instead of wrapping twice
        if hasattr(fn, "_sa_exclusion_extend"):
            fn._sa_exclusion_extend._extend(self)
            return fn

        @decorator
        def decorate(fn, *args, **kw):
            return self._do(config._current, fn, *args, **kw)

        decorated = decorate(fn)
        decorated._sa_exclusion_extend = self
        return decorated

    @contextlib.contextmanager
    def fail_if(self):
        # context-manager form: all predicates are treated as expected
        # failures for the enclosed block
        all_fails = compound()
        all_fails.fails.update(self.skips.union(self.fails))

        try:
            yield
        except Exception as ex:
            all_fails._expect_failure(config._current, ex)
        else:
            all_fails._expect_success(config._current)

    def _do(self, cfg, fn, *args, **kw):
        # run fn under this rule: skip first, then interpret the outcome
        # against the expected-failure predicates
        for skip in self.skips:
            if skip(cfg):
                msg = "'%s' : %s" % (
                    config.get_current_test_name(),
                    skip._as_string(cfg),
                )
                config.skip_test(msg)

        try:
            return_value = fn(*args, **kw)
        except Exception as ex:
            self._expect_failure(cfg, ex, name=fn.__name__)
        else:
            self._expect_success(cfg, name=fn.__name__)
            return return_value

    def _expect_failure(self, config, ex, name="block"):
        # swallow the exception if some fail-predicate matched; otherwise
        # re-raise with the original traceback
        for fail in self.fails:
            if fail(config):
                if util.py2k:
                    str_ex = unicode(ex).encode(  # noqa: F821
                        "utf-8", errors="ignore"
                    )
                else:
                    str_ex = str(ex)
                print(
                    (
                        "%s failed as expected (%s): %s "
                        % (name, fail._as_string(config), str_ex)
                    )
                )
                break
        else:
            util.raise_(ex, with_traceback=sys.exc_info()[2])

    def _expect_success(self, config, name="block"):
        # an expected failure that succeeds is itself an error
        if not self.fails:
            return

        for fail in self.fails:
            if fail(config):
                raise AssertionError(
                    "Unexpected success for '%s' (%s)"
                    % (
                        name,
                        " and ".join(
                            fail._as_string(config) for fail in self.fails
                        ),
                    )
                )
|
||||
|
||||
|
||||
def requires_tag(tagname):
    """Return a rule tagged with the single tag *tagname*."""
    return tags([tagname])
|
||||
|
||||
|
||||
def tags(tagnames):
    """Return a rule carrying the given tags (no skip/fail predicates)."""
    rule = compound()
    rule.tags.update(tagnames)
    return rule
|
||||
|
||||
|
||||
def only_if(predicate, reason=None):
    """Skip unless *predicate* holds (i.e. skip on its negation)."""
    pred = _as_predicate(predicate)
    return skip_if(NotPredicate(pred), reason)
|
||||
|
||||
|
||||
def succeeds_if(predicate, reason=None):
    """Expect failure unless *predicate* holds (fail on its negation)."""
    pred = _as_predicate(predicate)
    return fails_if(NotPredicate(pred), reason)
|
||||
|
||||
|
||||
class Predicate(object):
    """Base class for exclusion predicates.

    A predicate is called with the active test ``config`` and returns a
    boolean; subclasses implement ``__call__`` and ``_as_string``.
    """

    @classmethod
    def as_predicate(cls, predicate, description=None):
        """Coerce *predicate* into a :class:`Predicate` instance.

        Accepts an existing predicate, a ``compound`` rule, a list/set
        (OR'ed together), a ``(db, op, spec)`` tuple, a database spec
        string such as ``"postgresql>=9.4"``, or an arbitrary callable.

        :raises ValueError: if a string predicate has no database name.
        """
        if isinstance(predicate, compound):
            return cls.as_predicate(predicate.enabled_for_config, description)
        elif isinstance(predicate, Predicate):
            if description and predicate.description is None:
                predicate.description = description
            return predicate
        elif isinstance(predicate, (list, set)):
            return OrPredicate(
                [cls.as_predicate(pred) for pred in predicate], description
            )
        elif isinstance(predicate, tuple):
            return SpecPredicate(*predicate)
        elif isinstance(predicate, util.string_types):
            # parse "<db>[+driver] [<op> <version>]", e.g. "mysql>=8.0"
            tokens = re.match(
                r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
            )
            if not tokens:
                raise ValueError(
                    "Couldn't locate DB name in predicate: %r" % predicate
                )
            db = tokens.group(1)
            op = tokens.group(2)
            spec = (
                tuple(int(d) for d in tokens.group(3).split("."))
                if tokens.group(3)
                else None
            )

            return SpecPredicate(db, op, spec, description=description)
        elif callable(predicate):
            return LambdaPredicate(predicate, description)
        else:
            assert False, "unknown predicate type: %s" % predicate

    def _format_description(self, config, negate=False):
        # Render self.description with %-style mapping substitution.
        bool_ = self(config)
        if negate:
            # bugfix: invert the evaluated result.  Previously this was
            # ``bool_ = not negate``, which is always False when negating,
            # so the "does/doesn't support" wording came out wrong for
            # negated predicates that evaluated False.
            bool_ = not bool_
        return self.description % {
            "driver": config.db.url.get_driver_name()
            if config
            else "<no driver>",
            "database": config.db.url.get_backend_name()
            if config
            else "<no database>",
            "doesnt_support": "doesn't support" if bool_ else "does support",
            "does_support": "does support" if bool_ else "doesn't support",
        }

    def _as_string(self, config=None, negate=False):
        raise NotImplementedError()
|
||||
|
||||
|
||||
class BooleanPredicate(Predicate):
    # Constant predicate: evaluates to a fixed value, ignoring the
    # configuration entirely.

    def __init__(self, value, description=None):
        self.value = value
        self.description = description or "boolean %s" % value

    def __call__(self, config):
        return self.value

    def _as_string(self, config, negate=False):
        return self._format_description(config, negate=negate)
|
||||
|
||||
|
||||
class SpecPredicate(Predicate):
    """Predicate matching a database name, optional driver, and an
    optional server-version comparison, e.g. ``("postgresql", ">=", (9, 4))``.
    """

    _ops = {
        "<": operator.lt,
        ">": operator.gt,
        "==": operator.eq,
        "!=": operator.ne,
        "<=": operator.le,
        ">=": operator.ge,
        "in": operator.contains,
        "between": lambda val, pair: val >= pair[0] and val <= pair[1],
    }

    def __init__(self, db, op=None, spec=None, description=None):
        self.db = db
        self.op = op
        self.spec = spec
        self.description = description

    def __call__(self, config):
        if config is None:
            return False

        engine = config.db

        # "dialect+driver" targets a specific DBAPI driver
        if "+" in self.db:
            dialect, driver = self.db.split("+")
        else:
            dialect, driver = self.db, None

        if dialect and engine.name != dialect:
            return False
        if driver is not None and engine.driver != driver:
            return False

        if self.op is None:
            return True

        assert driver is None, "DBAPI version specs not supported yet"

        version = _server_version(engine)
        oper = hasattr(self.op, "__call__") and self.op or self._ops[self.op]
        return oper(version, self.spec)

    def _as_string(self, config, negate=False):
        if self.description is not None:
            return self._format_description(config)
        prefix = "not " if negate else ""
        if self.op is None:
            return prefix + "%s" % self.db
        return prefix + "%s %s %s" % (self.db, self.op, self.spec)
|
||||
|
||||
|
||||
class LambdaPredicate(Predicate):
    """Predicate driven by a user-supplied callable."""

    def __init__(self, lambda_, description=None, args=None, kw=None):
        spec = inspect_getfullargspec(lambda_)
        if spec[0]:
            self.lambda_ = lambda_
        else:
            # zero-argument callable: adapt it to the (config) signature
            self.lambda_ = lambda db: lambda_()
        self.args = args or ()
        self.kw = kw or {}
        # prefer an explicit description, then the callable's docstring
        if description:
            self.description = description
        elif lambda_.__doc__:
            self.description = lambda_.__doc__
        else:
            self.description = "custom function"

    def __call__(self, config):
        return self.lambda_(config)

    def _as_string(self, config, negate=False):
        return self._format_description(config)
|
||||
|
||||
|
||||
class NotPredicate(Predicate):
    """Invert the result of another predicate."""

    def __init__(self, predicate, description=None):
        self.predicate = predicate
        self.description = description

    def __call__(self, config):
        return not self.predicate(config)

    def _as_string(self, config, negate=False):
        # our own description, if any, wins; otherwise delegate to the
        # wrapped predicate with the negation sense flipped
        if self.description:
            return self._format_description(config, not negate)
        return self.predicate._as_string(config, not negate)
|
||||
|
||||
|
||||
class OrPredicate(Predicate):
    """True if any member predicate is true."""

    def __init__(self, predicates, description=None):
        self.predicates = predicates
        self.description = description

    def __call__(self, config):
        return any(pred(config) for pred in self.predicates)

    def _eval_str(self, config, negate=False):
        # De Morgan: the negation of an OR reads as an AND of negations
        conjunction = " and " if negate else " or "
        return conjunction.join(
            p._as_string(config, negate=negate) for p in self.predicates
        )

    def _negation_str(self, config):
        if self.description is not None:
            return "Not " + self._format_description(config)
        return self._eval_str(config, negate=True)

    def _as_string(self, config, negate=False):
        if negate:
            return self._negation_str(config)
        if self.description is not None:
            return self._format_description(config)
        return self._eval_str(config)
|
||||
|
||||
|
||||
# module-level shorthand used throughout this module
_as_predicate = Predicate.as_predicate
|
||||
|
||||
|
||||
def _is_excluded(db, op, spec):
    """Evaluate a (db, op, spec) exclusion against the current config."""
    pred = SpecPredicate(db, op, spec)
    return pred(config._current)
|
||||
|
||||
|
||||
def _server_version(engine):
|
||||
"""Return a server_version_info tuple."""
|
||||
|
||||
# force metadata to be retrieved
|
||||
conn = engine.connect()
|
||||
version = getattr(engine.dialect, "server_version_info", None)
|
||||
if version is None:
|
||||
version = ()
|
||||
conn.close()
|
||||
return version
|
||||
|
||||
|
||||
def db_spec(*dbs):
    """OR together predicates coerced from each of *dbs*."""
    preds = [Predicate.as_predicate(db) for db in dbs]
    return OrPredicate(preds)
|
||||
|
||||
|
||||
def open():  # noqa
    """Rule that never skips: the test always executes."""
    pred = BooleanPredicate(False, "mark as execute")
    return skip_if(pred)
|
||||
|
||||
|
||||
def closed():
    """Rule that always skips."""
    pred = BooleanPredicate(True, "marked as skip")
    return skip_if(pred)
|
||||
|
||||
|
||||
def fails(reason=None):
    """Rule that always expects failure."""
    pred = BooleanPredicate(True, reason or "expected to fail")
    return fails_if(pred)
|
||||
|
||||
|
||||
@decorator
def future(fn, *arg):
    # mark the decorated callable as an expected failure for a future
    # feature; the callable itself acts as the predicate
    return fails_if(LambdaPredicate(fn), "Future feature")
|
||||
|
||||
|
||||
def fails_on(db, reason=None):
    """Expect failure when running against *db*."""
    return fails_if(db, reason)
|
||||
|
||||
|
||||
def fails_on_everything_except(*dbs):
    """Expect failure on every backend except the given ones."""
    preds = [Predicate.as_predicate(db) for db in dbs]
    return succeeds_if(OrPredicate(preds))
|
||||
|
||||
|
||||
def skip(db, reason=None):
    """Skip the test when running against *db*."""
    return skip_if(db, reason)
|
||||
|
||||
|
||||
def only_on(dbs, reason=None):
    """Run the test only on the given database(s)."""
    preds = [
        Predicate.as_predicate(db, reason) for db in util.to_list(dbs)
    ]
    return only_if(OrPredicate(preds))
|
||||
|
||||
|
||||
def exclude(db, op, spec, reason=None):
    """Skip when the backend matches the (db, op, spec) version rule."""
    pred = SpecPredicate(db, op, spec)
    return skip_if(pred, reason)
|
||||
|
||||
|
||||
def against(config, *queries):
    """Return True if *config* matches any of the given backend queries."""
    assert queries, "no queries sent!"
    pred = OrPredicate([Predicate.as_predicate(query) for query in queries])
    return pred(config)
|
||||
870
lib/sqlalchemy/testing/fixtures.py
Normal file
870
lib/sqlalchemy/testing/fixtures.py
Normal file
@@ -0,0 +1,870 @@
|
||||
# testing/fixtures.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
import sys
|
||||
|
||||
import sqlalchemy as sa
|
||||
from . import assertions
|
||||
from . import config
|
||||
from . import schema
|
||||
from .entities import BasicEntity
|
||||
from .entities import ComparableEntity
|
||||
from .entities import ComparableMixin # noqa
|
||||
from .util import adict
|
||||
from .util import drop_all_tables_from_metadata
|
||||
from .. import event
|
||||
from .. import util
|
||||
from ..orm import declarative_base
|
||||
from ..orm import registry
|
||||
from ..orm.decl_api import DeclarativeMeta
|
||||
from ..schema import sort_tables_and_constraints
|
||||
|
||||
|
||||
@config.mark_base_test_class()
class TestBase(object):
    """Root test class: declares exclusion hooks and the common fixtures
    (connections, engines, metadata, registries) used by the suite."""

    # A sequence of requirement names matching testing.requires decorators
    __requires__ = ()

    # A sequence of dialect names to exclude from the test class.
    __unsupported_on__ = ()

    # If present, test class is only runnable for the *single* specified
    # dialect. If you need multiple, use __unsupported_on__ and invert.
    __only_on__ = None

    # A sequence of no-arg callables. If any are True, the entire testcase is
    # skipped.
    __skip_if__ = None

    # if True, the testing reaper will not attempt to touch connection
    # state after a test is completed and before the outer teardown
    # starts
    __leave_connections_for_teardown__ = False

    def assert_(self, val, msg=None):
        assert val, msg

    @config.fixture()
    def nocache(self):
        # temporarily disable the compiled-statement cache for a test
        _cache = config.db._compiled_cache
        config.db._compiled_cache = None
        yield
        config.db._compiled_cache = _cache

    @config.fixture()
    def connection_no_trans(self):
        # a plain connection with no explicit transaction begun
        eng = getattr(self, "bind", None) or config.db

        with eng.connect() as conn:
            yield conn

    @config.fixture()
    def connection(self):
        # a connection inside an explicit transaction, rolled back at
        # teardown; also tracked module-wide so the metadata fixture can
        # coordinate table drops with it
        global _connection_fixture_connection

        eng = getattr(self, "bind", None) or config.db

        conn = eng.connect()
        trans = conn.begin()

        _connection_fixture_connection = conn
        yield conn

        _connection_fixture_connection = None

        if trans.is_active:
            trans.rollback()
        # trans would not be active here if the test is using
        # the legacy @provide_metadata decorator still, as it will
        # run a close all connections.
        conn.close()

    @config.fixture()
    def close_result_when_finished(self):
        # register result objects to be (optionally consumed and) closed
        # at teardown, best-effort
        to_close = []
        to_consume = []

        def go(result, consume=False):
            to_close.append(result)
            if consume:
                to_consume.append(result)

        yield go
        for r in to_consume:
            try:
                r.all()
            except:
                pass
        for r in to_close:
            try:
                r.close()
            except:
                pass

    @config.fixture()
    def registry(self, metadata):
        # an ORM registry bound to the per-test metadata; disposed after
        reg = registry(metadata=metadata)
        yield reg
        reg.dispose()

    @config.fixture
    def decl_base(self, registry):
        return registry.generate_base()

    @config.fixture()
    def future_connection(self, future_engine, connection):
        # integrate the future_engine and connection fixtures so
        # that users of the "connection" fixture will get at the
        # "future" connection
        yield connection

    @config.fixture()
    def future_engine(self):
        eng = getattr(self, "bind", None) or config.db
        with _push_future_engine(eng):
            yield

    @config.fixture()
    def testing_engine(self):
        # factory fixture wrapping engines.testing_engine with
        # fixture-scoped reaping
        from . import engines

        def gen_testing_engine(
            url=None,
            options=None,
            future=None,
            asyncio=False,
            transfer_staticpool=False,
        ):
            if options is None:
                options = {}
            options["scope"] = "fixture"
            return engines.testing_engine(
                url=url,
                options=options,
                future=future,
                asyncio=asyncio,
                transfer_staticpool=transfer_staticpool,
            )

        yield gen_testing_engine

        engines.testing_reaper._drop_testing_engines("fixture")

    @config.fixture()
    def async_testing_engine(self, testing_engine):
        def go(**kw):
            kw["asyncio"] = True
            return testing_engine(**kw)

        return go

    @config.fixture
    def fixture_session(self):
        # resolves to the module-level fixture_session() factory
        return fixture_session()

    @config.fixture()
    def metadata(self, request):
        """Provide bound MetaData for a single test, dropping afterwards."""

        from ..sql import schema

        metadata = schema.MetaData()
        request.instance.metadata = metadata
        yield metadata
        del request.instance.metadata

        if (
            _connection_fixture_connection
            and _connection_fixture_connection.in_transaction()
        ):
            # the "connection" fixture is active: roll back its work and
            # drop tables on that same connection
            trans = _connection_fixture_connection.get_transaction()
            trans.rollback()
            with _connection_fixture_connection.begin():
                drop_all_tables_from_metadata(
                    metadata, _connection_fixture_connection
                )
        else:
            drop_all_tables_from_metadata(metadata, config.db)

    @config.fixture(
        params=[
            (rollback, second_operation, begin_nested)
            for rollback in (True, False)
            for second_operation in ("none", "execute", "begin")
            for begin_nested in (
                True,
                False,
            )
        ]
    )
    def trans_ctx_manager_fixture(self, request, metadata):
        # parametrized scenario runner exercising transaction context
        # managers (plain and SAVEPOINT) against a Connection, Engine or
        # Session-like "subject"
        rollback, second_operation, begin_nested = request.param

        from sqlalchemy import Table, Column, Integer, func, select
        from . import eq_

        t = Table("test", metadata, Column("data", Integer))
        eng = getattr(self, "bind", None) or config.db

        t.create(eng)

        def run_test(subject, trans_on_subject, execute_on_subject):
            with subject.begin() as trans:

                if begin_nested:
                    if not config.requirements.savepoints.enabled:
                        config.skip_test("savepoints not enabled")
                    if execute_on_subject:
                        nested_trans = subject.begin_nested()
                    else:
                        nested_trans = trans.begin_nested()

                    with nested_trans:
                        if execute_on_subject:
                            subject.execute(t.insert(), {"data": 10})
                        else:
                            trans.execute(t.insert(), {"data": 10})

                        # for nested trans, we always commit/rollback on the
                        # "nested trans" object itself.
                        # only Session(future=False) will affect savepoint
                        # transaction for session.commit/rollback

                        if rollback:
                            nested_trans.rollback()
                        else:
                            nested_trans.commit()

                        if second_operation != "none":
                            with assertions.expect_raises_message(
                                sa.exc.InvalidRequestError,
                                "Can't operate on closed transaction "
                                "inside context "
                                "manager. Please complete the context "
                                "manager "
                                "before emitting further commands.",
                            ):
                                if second_operation == "execute":
                                    if execute_on_subject:
                                        subject.execute(
                                            t.insert(), {"data": 12}
                                        )
                                    else:
                                        trans.execute(t.insert(), {"data": 12})
                                elif second_operation == "begin":
                                    if execute_on_subject:
                                        subject.begin_nested()
                                    else:
                                        trans.begin_nested()

                    # outside the nested trans block, but still inside the
                    # transaction block, we can run SQL, and it will be
                    # committed
                    if execute_on_subject:
                        subject.execute(t.insert(), {"data": 14})
                    else:
                        trans.execute(t.insert(), {"data": 14})

                else:
                    if execute_on_subject:
                        subject.execute(t.insert(), {"data": 10})
                    else:
                        trans.execute(t.insert(), {"data": 10})

                    if trans_on_subject:
                        if rollback:
                            subject.rollback()
                        else:
                            subject.commit()
                    else:
                        if rollback:
                            trans.rollback()
                        else:
                            trans.commit()

            if second_operation != "none":
                with assertions.expect_raises_message(
                    sa.exc.InvalidRequestError,
                    "Can't operate on closed transaction inside "
                    "context "
                    "manager. Please complete the context manager "
                    "before emitting further commands.",
                ):
                    if second_operation == "execute":
                        if execute_on_subject:
                            subject.execute(t.insert(), {"data": 12})
                        else:
                            trans.execute(t.insert(), {"data": 12})
                    elif second_operation == "begin":
                        if hasattr(trans, "begin"):
                            trans.begin()
                        else:
                            subject.begin()
                    elif second_operation == "begin_nested":
                        if execute_on_subject:
                            subject.begin_nested()
                        else:
                            trans.begin_nested()

            expected_committed = 0
            if begin_nested:
                # begin_nested variant, we inserted a row after the nested
                # block
                expected_committed += 1
            if not rollback:
                # not rollback variant, our row inserted in the target
                # block itself would be committed
                expected_committed += 1

            if execute_on_subject:
                eq_(
                    subject.scalar(select(func.count()).select_from(t)),
                    expected_committed,
                )
            else:
                with subject.connect() as conn:
                    eq_(
                        conn.scalar(select(func.count()).select_from(t)),
                        expected_committed,
                    )

        return run_test
|
||||
|
||||
|
||||
# module-wide handle on the active "connection" fixture's connection, used
# by the metadata fixture to coordinate teardown; None when inactive
_connection_fixture_connection = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _push_future_engine(engine):
    """Install a 1.4 "future" facade for ``engine`` as the current
    test engine for the duration of the ``with`` block.

    The facade is pushed onto the testing config's engine stack and is
    popped again on exit.  The pop now runs in a ``finally`` so that an
    exception raised inside the block cannot leave the stack unbalanced
    (the original popped only on normal exit).

    :param engine: the non-future Engine to wrap.
    :yield: the future-style facade produced by ``Engine._future_facade``.
    """

    from ..future.engine import Engine
    from sqlalchemy import testing

    facade = Engine._future_facade(engine)
    config._current.push_engine(facade, testing)

    try:
        yield facade
    finally:
        # restore the previous engine even when the body raises
        config._current.pop(testing)
|
||||
|
||||
|
||||
class FutureEngineMixin(object):
    """Mixin that runs every test in the class against a "future"
    engine facade pushed via :func:`_push_future_engine`."""

    @config.fixture(autouse=True, scope="class")
    def _push_future_engine(self):
        # prefer a class-level ``bind`` when one is set; otherwise fall
        # back to the default configured database
        engine = getattr(self, "bind", None)
        if not engine:
            engine = config.db
        with _push_future_engine(engine):
            yield
|
||||
|
||||
|
||||
class TablesTest(TestBase):
    """Test base that manages a per-class :class:`.MetaData`, table
    creation/teardown, and fixture-row inserts.

    The ``run_*`` class flags control whether each lifecycle step runs
    once per class, once per test ("each"), or not at all (None).
    """

    # 'once', None
    run_setup_bind = "once"

    # 'once', 'each', None
    run_define_tables = "once"

    # 'once', 'each', None
    run_create_tables = "once"

    # 'once', 'each', None
    run_inserts = "each"

    # 'each', None
    run_deletes = "each"

    # 'once', None
    run_dispose_bind = None

    # populated by _init_class(); ``tables``/``sequences`` mirror the
    # contents of ``_tables_metadata`` for attribute-style access
    bind = None
    _tables_metadata = None
    tables = None
    other = None
    sequences = None

    @config.fixture(autouse=True, scope="class")
    def _setup_tables_test_class(self):
        # class-scoped lifecycle: init, "once" tables and inserts,
        # then teardown after all tests in the class have run
        cls = self.__class__
        cls._init_class()

        cls._setup_once_tables()

        cls._setup_once_inserts()

        yield

        cls._teardown_once_metadata_bind()

    @config.fixture(autouse=True, scope="function")
    def _setup_tables_test_instance(self):
        # per-test lifecycle for the "each" variants of the flags
        self._setup_each_tables()
        self._setup_each_inserts()

        yield

        self._teardown_each_tables()

    @property
    def tables_test_metadata(self):
        """The MetaData collection managed by this test class."""
        return self._tables_metadata

    @classmethod
    def _init_class(cls):
        """Normalize flags and create the per-class bind/metadata."""
        if cls.run_define_tables == "each":
            # tables redefined per test imply per-test creation too
            if cls.run_create_tables == "once":
                cls.run_create_tables = "each"
            assert cls.run_inserts in ("each", None)

        cls.other = adict()
        cls.tables = adict()
        cls.sequences = adict()

        cls.bind = cls.setup_bind()
        cls._tables_metadata = sa.MetaData()

    @classmethod
    def _setup_once_inserts(cls):
        # load fixture rows and run insert_data() once per class
        if cls.run_inserts == "once":
            cls._load_fixtures()
            with cls.bind.begin() as conn:
                cls.insert_data(conn)

    @classmethod
    def _setup_once_tables(cls):
        # define (and optionally create) tables once per class
        if cls.run_define_tables == "once":
            cls.define_tables(cls._tables_metadata)
            if cls.run_create_tables == "once":
                cls._tables_metadata.create_all(cls.bind)
            cls.tables.update(cls._tables_metadata.tables)
            cls.sequences.update(cls._tables_metadata._sequences)

    def _setup_each_tables(self):
        # per-test table definition/creation, mirroring the "once" path
        if self.run_define_tables == "each":
            self.define_tables(self._tables_metadata)
            if self.run_create_tables == "each":
                self._tables_metadata.create_all(self.bind)
            self.tables.update(self._tables_metadata.tables)
            self.sequences.update(self._tables_metadata._sequences)
        elif self.run_create_tables == "each":
            self._tables_metadata.create_all(self.bind)

    def _setup_each_inserts(self):
        # per-test fixture loading and insert_data() hook
        if self.run_inserts == "each":
            self._load_fixtures()
            with self.bind.begin() as conn:
                self.insert_data(conn)

    def _teardown_each_tables(self):
        if self.run_define_tables == "each":
            self.tables.clear()
            if self.run_create_tables == "each":
                drop_all_tables_from_metadata(self._tables_metadata, self.bind)
            self._tables_metadata.clear()
        elif self.run_create_tables == "each":
            drop_all_tables_from_metadata(self._tables_metadata, self.bind)

        savepoints = getattr(config.requirements, "savepoints", False)
        if savepoints:
            savepoints = savepoints.enabled

        # no need to run deletes if tables are recreated on setup
        if (
            self.run_define_tables != "each"
            and self.run_create_tables != "each"
            and self.run_deletes == "each"
        ):
            with self.bind.begin() as conn:
                # delete in reverse dependency order so FK constraints
                # are not violated
                for table in reversed(
                    [
                        t
                        for (t, fks) in sort_tables_and_constraints(
                            self._tables_metadata.tables.values()
                        )
                        if t is not None
                    ]
                ):
                    try:
                        if savepoints:
                            # isolate each DELETE so one failure does
                            # not abort the whole transaction
                            with conn.begin_nested():
                                conn.execute(table.delete())
                        else:
                            conn.execute(table.delete())
                    except sa.exc.DBAPIError as ex:
                        # best-effort cleanup: report and continue
                        util.print_(
                            ("Error emptying table %s: %r" % (table, ex)),
                            file=sys.stderr,
                        )

    @classmethod
    def _teardown_once_metadata_bind(cls):
        if cls.run_create_tables:
            drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)

        if cls.run_dispose_bind == "once":
            cls.dispose_bind(cls.bind)

        cls._tables_metadata.bind = None

        if cls.run_setup_bind is not None:
            cls.bind = None

    @classmethod
    def setup_bind(cls):
        """Return the engine/connection tests should run against."""
        return config.db

    @classmethod
    def dispose_bind(cls, bind):
        """Dispose or close the bind, whichever it supports."""
        if hasattr(bind, "dispose"):
            bind.dispose()
        elif hasattr(bind, "close"):
            bind.close()

    @classmethod
    def define_tables(cls, metadata):
        """Hook for subclasses: add Table objects to ``metadata``."""
        pass

    @classmethod
    def fixtures(cls):
        """Hook for subclasses: return {table: [header_row, *data_rows]}."""
        return {}

    @classmethod
    def insert_data(cls, connection):
        """Hook for subclasses: insert additional data per run_inserts."""
        pass

    def sql_count_(self, count, fn):
        self.assert_sql_count(self.bind, fn, count)

    def sql_eq_(self, callable_, statements):
        self.assert_sql(self.bind, callable_, statements)

    @classmethod
    def _load_fixtures(cls):
        """Insert rows as represented by the fixtures() method."""
        headers, rows = {}, {}
        for table, data in cls.fixtures().items():
            # first element is the column-name header; require at
            # least one data row beyond it
            if len(data) < 2:
                continue
            if isinstance(table, util.string_types):
                table = cls.tables[table]
            headers[table] = data[0]
            rows[table] = data[1:]
        # insert in dependency order so FK targets exist first
        for table, fks in sort_tables_and_constraints(
            cls._tables_metadata.tables.values()
        ):
            if table is None:
                continue
            if table not in headers:
                continue
            with cls.bind.begin() as conn:
                conn.execute(
                    table.insert(),
                    [
                        dict(zip(headers[table], column_values))
                        for column_values in rows[table]
                    ],
                )
|
||||
|
||||
|
||||
class NoCache(object):
    """Mixin that disables the compiled-statement cache for each test.

    The previous cache is stashed before the test and restored
    afterwards.  Restoration now happens in a ``finally`` so that an
    exception thrown through the fixture cannot leave caching disabled
    for subsequent tests (the original restored only on normal exit).
    """

    @config.fixture(autouse=True, scope="function")
    def _disable_cache(self):
        _cache = config.db._compiled_cache
        config.db._compiled_cache = None
        try:
            yield
        finally:
            config.db._compiled_cache = _cache
|
||||
|
||||
|
||||
class RemovesEvents(object):
    """Mixin providing tracked event registration.

    Listeners attached via :meth:`.event_listen` are recorded and
    removed after each test.  Removal now runs in a ``finally`` so a
    failure propagating through the fixture cannot leak listeners into
    later tests (the original removed only on normal exit).
    """

    @util.memoized_property
    def _event_fns(self):
        # lazily-created record of (target, identifier, fn) triples
        return set()

    def event_listen(self, target, name, fn, **kw):
        """Attach ``fn`` via the event system, remembering it so it is
        removed during teardown."""
        self._event_fns.add((target, name, fn))
        event.listen(target, name, fn, **kw)

    @config.fixture(autouse=True, scope="function")
    def _remove_events(self):
        try:
            yield
        finally:
            for key in self._event_fns:
                event.remove(*key)
|
||||
|
||||
|
||||
_fixture_sessions = set()
|
||||
|
||||
|
||||
def fixture_session(**kw):
    """Return an ORM :class:`.Session` that is tracked for teardown.

    Keyword arguments are passed through to the Session constructor;
    ``autoflush`` and ``expire_on_commit`` default to True, and
    ``bind`` defaults to the configured test database.
    """
    for key, default in (("autoflush", True), ("expire_on_commit", True)):
        kw.setdefault(key, default)

    target_bind = kw.pop("bind", config.db)

    session = sa.orm.Session(target_bind, **kw)
    # remember the session so _close_all_sessions() can clean it up
    _fixture_sessions.add(session)
    return session
|
||||
|
||||
|
||||
def _close_all_sessions():
    """Close every still-referenced ORM session and forget the tracked set."""
    # will close all still-referenced sessions
    sa.orm.session.close_all_sessions()
    _fixture_sessions.clear()
|
||||
|
||||
|
||||
def stop_test_class_inside_fixtures(cls):
    """Plugin hook run when a test class finishes: close any leftover
    sessions and dispose of all mapper configuration."""
    _close_all_sessions()
    sa.orm.clear_mappers()
|
||||
|
||||
|
||||
def after_test():
    """Plugin hook run after each test; closes tracked sessions if any
    were created by fixture_session()."""
    if _fixture_sessions:
        _close_all_sessions()
|
||||
|
||||
|
||||
class ORMTest(TestBase):
    # marker base class for ORM-oriented tests; adds no behavior of its own
    pass
|
||||
|
||||
|
||||
class MappedTest(TablesTest, assertions.AssertsExecutionResults):
    """TablesTest variant that additionally manages mapped classes.

    Classes created inside :meth:`.setup_classes` /
    :meth:`.setup_mappers` are captured into the ``classes`` registry
    via a metaclass, and mapper state is torn down according to the
    ``run_setup_classes`` / ``run_setup_mappers`` flags.
    """

    # 'once', 'each', None
    run_setup_classes = "once"

    # 'once', 'each', None
    run_setup_mappers = "each"

    # registry of captured fixture classes, keyed by class name
    classes = None

    @config.fixture(autouse=True, scope="class")
    def _setup_tables_test_class(self):
        # extends the TablesTest class-scoped lifecycle with class and
        # mapper setup/teardown
        cls = self.__class__
        cls._init_class()

        if cls.classes is None:
            cls.classes = adict()

        cls._setup_once_tables()
        cls._setup_once_classes()
        cls._setup_once_mappers()
        cls._setup_once_inserts()

        yield

        cls._teardown_once_class()
        cls._teardown_once_metadata_bind()

    @config.fixture(autouse=True, scope="function")
    def _setup_tables_test_instance(self):
        self._setup_each_tables()
        self._setup_each_classes()
        self._setup_each_mappers()
        self._setup_each_inserts()

        yield

        # sessions must be gone before mappers/classes are cleared
        sa.orm.session.close_all_sessions()
        self._teardown_each_mappers()
        self._teardown_each_classes()
        self._teardown_each_tables()

    @classmethod
    def _teardown_once_class(cls):
        cls.classes.clear()

    @classmethod
    def _setup_once_classes(cls):
        if cls.run_setup_classes == "once":
            cls._with_register_classes(cls.setup_classes)

    @classmethod
    def _setup_once_mappers(cls):
        if cls.run_setup_mappers == "once":
            cls.mapper_registry, cls.mapper = cls._generate_registry()
            cls._with_register_classes(cls.setup_mappers)

    def _setup_each_mappers(self):
        # a fresh registry is generated for every non-"once" run
        if self.run_setup_mappers != "once":
            (
                self.__class__.mapper_registry,
                self.__class__.mapper,
            ) = self._generate_registry()

        if self.run_setup_mappers == "each":
            self._with_register_classes(self.setup_mappers)

    def _setup_each_classes(self):
        if self.run_setup_classes == "each":
            self._with_register_classes(self.setup_classes)

    @classmethod
    def _generate_registry(cls):
        # returns (registry, imperative-mapping callable) pair
        decl = registry(metadata=cls._tables_metadata)
        return decl, decl.map_imperatively

    @classmethod
    def _with_register_classes(cls, fn):
        """Run a setup method, framing the operation with a Base class
        that will catch new subclasses to be established within
        the "classes" registry.

        """
        cls_registry = cls.classes

        assert cls_registry is not None

        class FindFixture(type):
            # metaclass that records every created subclass by name
            def __init__(cls, classname, bases, dict_):
                cls_registry[classname] = cls
                type.__init__(cls, classname, bases, dict_)

        class _Base(util.with_metaclass(FindFixture, object)):
            pass

        class Basic(BasicEntity, _Base):
            pass

        class Comparable(ComparableEntity, _Base):
            pass

        cls.Basic = Basic
        cls.Comparable = Comparable
        fn()

    def _teardown_each_mappers(self):
        # some tests create mappers in the test bodies
        # and will define setup_mappers as None -
        # clear mappers in any case
        if self.run_setup_mappers != "once":
            sa.orm.clear_mappers()

    def _teardown_each_classes(self):
        if self.run_setup_classes != "once":
            self.classes.clear()

    @classmethod
    def setup_classes(cls):
        """Hook for subclasses: define fixture classes here."""
        pass

    @classmethod
    def setup_mappers(cls):
        """Hook for subclasses: configure mappers here."""
        pass
|
||||
|
||||
|
||||
class DeclarativeMappedTest(MappedTest):
    """MappedTest variant where fixture classes are declaratively
    mapped against a generated declarative base."""

    run_setup_classes = "once"
    run_setup_mappers = "once"

    @classmethod
    def _setup_once_tables(cls):
        # tables come from the declarative classes; nothing to do here
        pass

    @classmethod
    def _with_register_classes(cls, fn):
        cls_registry = cls.classes

        class FindFixtureDeclarative(DeclarativeMeta):
            # record every declaratively-created subclass by name
            def __init__(cls, classname, bases, dict_):
                cls_registry[classname] = cls
                DeclarativeMeta.__init__(cls, classname, bases, dict_)

        class DeclarativeBasic(object):
            __table_cls__ = schema.Table

        _DeclBase = declarative_base(
            metadata=cls._tables_metadata,
            metaclass=FindFixtureDeclarative,
            cls=DeclarativeBasic,
        )

        cls.DeclarativeBasic = _DeclBase

        # sets up cls.Basic which is helpful for things like composite
        # classes
        super(DeclarativeMappedTest, cls)._with_register_classes(fn)

        if cls._tables_metadata.tables and cls.run_create_tables:
            cls._tables_metadata.create_all(config.db)
|
||||
|
||||
|
||||
class ComputedReflectionFixtureTest(TablesTest):
    """Fixture tables for testing reflection of Computed columns."""

    run_inserts = run_deletes = None

    __backend__ = True
    __requires__ = ("computed_columns", "table_reflection")

    # strips brackets, quotes and whitespace so reflected SQL text can
    # be compared across dialects
    regexp = re.compile(r"[\[\]\(\)\s`'\"]*")

    def normalize(self, text):
        """Return ``text`` lowercased with punctuation/whitespace removed."""
        return self.regexp.sub("", text).lower()

    @classmethod
    def define_tables(cls, metadata):
        from .. import Integer
        from .. import testing
        from ..schema import Column
        from ..schema import Computed
        from ..schema import Table

        Table(
            "computed_default_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_col", Integer, Computed("normal + 42")),
            Column("with_default", Integer, server_default="42"),
        )

        t = Table(
            "computed_column_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("normal", Integer),
            Column("computed_no_flag", Integer, Computed("normal + 42")),
        )

        if testing.requires.schemas.enabled:
            # same-named table in the alternate test schema, with a
            # different SQL expression to tell them apart
            t2 = Table(
                "computed_column_table",
                metadata,
                Column("id", Integer, primary_key=True),
                Column("normal", Integer),
                Column("computed_no_flag", Integer, Computed("normal / 42")),
                schema=config.test_schema,
            )

        if testing.requires.computed_columns_virtual.enabled:
            t.append_column(
                Column(
                    "computed_virtual",
                    Integer,
                    Computed("normal + 2", persisted=False),
                )
            )
            if testing.requires.schemas.enabled:
                t2.append_column(
                    Column(
                        "computed_virtual",
                        Integer,
                        Computed("normal / 2", persisted=False),
                    )
                )
        if testing.requires.computed_columns_stored.enabled:
            t.append_column(
                Column(
                    "computed_stored",
                    Integer,
                    Computed("normal - 42", persisted=True),
                )
            )
            if testing.requires.schemas.enabled:
                t2.append_column(
                    Column(
                        "computed_stored",
                        Integer,
                        Computed("normal * 42", persisted=True),
                    )
                )
|
||||
32
lib/sqlalchemy/testing/mock.py
Normal file
32
lib/sqlalchemy/testing/mock.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# testing/mock.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

"""Import stub for mock library.
"""
from __future__ import absolute_import

from ..util import py3k


if py3k:
    # Python 3: mock ships with the standard library
    from unittest.mock import MagicMock
    from unittest.mock import Mock
    from unittest.mock import call
    from unittest.mock import patch
    from unittest.mock import ANY
else:
    # Python 2: require the standalone "mock" backport package
    try:
        from mock import MagicMock  # noqa
        from mock import Mock  # noqa
        from mock import call  # noqa
        from mock import patch  # noqa
        from mock import ANY  # noqa
    except ImportError:
        raise ImportError(
            "SQLAlchemy's test suite requires the "
            "'mock' library as of 0.8.2."
        )
|
||||
151
lib/sqlalchemy/testing/pickleable.py
Normal file
151
lib/sqlalchemy/testing/pickleable.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# testing/pickleable.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Classes used in pickling tests, need to be at the module level for
|
||||
unpickling.
|
||||
"""
|
||||
|
||||
from . import fixtures
|
||||
from ..schema import Column
|
||||
from ..types import String
|
||||
|
||||
|
||||
class User(fixtures.ComparableEntity):
    # basic "user" entity for pickling round-trip tests
    pass
|
||||
|
||||
|
||||
class Order(fixtures.ComparableEntity):
    # basic "order" entity for pickling round-trip tests
    pass
|
||||
|
||||
|
||||
class Dingaling(fixtures.ComparableEntity):
    # auxiliary entity for pickling round-trip tests
    pass
|
||||
|
||||
|
||||
class EmailUser(User):
    # User subclass, used to exercise inheritance in pickling tests
    pass
|
||||
|
||||
|
||||
class Address(fixtures.ComparableEntity):
    # basic "address" entity for pickling round-trip tests
    pass
|
||||
|
||||
|
||||
# TODO: these are kind of arbitrary....
|
||||
class Child1(fixtures.ComparableEntity):
    # child entity for parent/child relationship pickling tests
    pass
|
||||
|
||||
|
||||
class Child2(fixtures.ComparableEntity):
    # second child entity for parent/child relationship pickling tests
    pass
|
||||
|
||||
|
||||
class Parent(fixtures.ComparableEntity):
    # parent entity for parent/child relationship pickling tests
    pass
|
||||
|
||||
|
||||
class Screen(object):
    """Plain holder pairing an object with an optional parent screen."""

    def __init__(self, obj, parent=None):
        # record the wrapped object and the optional parent link
        self.obj = obj
        self.parent = parent
|
||||
|
||||
|
||||
class Mixin(object):
    # declarative mixin contributing a shared email_address column
    email_address = Column(String)
|
||||
|
||||
|
||||
class AddressWMixin(Mixin, fixtures.ComparableEntity):
    # Address variant that picks up the Mixin's email_address column
    pass
|
||||
|
||||
|
||||
class Foo(object):
    """Pickling-test class with a fixed ``data`` attribute plus
    constructor-supplied ``stuff`` and ``moredata``."""

    def __init__(self, moredata, stuff="im stuff"):
        self.data = "im data"
        self.stuff = stuff
        self.moredata = moredata

    # identity hashing, deliberately decoupled from __eq__
    __hash__ = object.__hash__

    def __eq__(self, other):
        # attribute-wise comparison; intentionally no type check
        return (
            self.data == other.data
            and self.stuff == other.stuff
            and self.moredata == other.moredata
        )
|
||||
|
||||
|
||||
class Bar(object):
    """Comparable two-field value object used in pickling tests."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    # identity hashing, deliberately decoupled from __eq__
    __hash__ = object.__hash__

    def __eq__(self, other):
        # equal only to instances of the same concrete class with
        # matching coordinates
        if other.__class__ is not self.__class__:
            return False
        return other.x == self.x and other.y == self.y

    def __str__(self):
        return "Bar(%d, %d)" % (self.x, self.y)
|
||||
|
||||
|
||||
class OldSchool:
    """Comparable pair written as a py2 "old-style" class declaration."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        # same concrete class with matching coordinates; note no
        # __hash__ is defined alongside __eq__
        if self.__class__ is not other.__class__:
            return False
        return self.x == other.x and self.y == other.y
|
||||
|
||||
|
||||
class OldSchoolWithoutCompare:
    """Like OldSchool but relying on default identity comparison."""

    def __init__(self, x, y):
        # store both coordinates as given
        self.x, self.y = x, y
|
||||
|
||||
|
||||
class BarWithoutCompare(object):
    """Bar variant with identity equality; only str() is customized."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __str__(self):
        # note: renders as "Bar", matching Bar's string output
        return "Bar(%d, %d)" % (self.x, self.y)
|
||||
|
||||
|
||||
class NotComparable(object):
    """Object that opts out of value comparison entirely.

    Both comparison methods return NotImplemented, so Python falls
    back to identity semantics; hashing is explicitly by id().
    """

    def __init__(self, data):
        self.data = data

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return NotImplemented

    def __ne__(self, other):
        return NotImplemented
|
||||
|
||||
|
||||
class BrokenComparable(object):
    """Object whose comparison methods deliberately blow up, for
    testing code paths that must not compare objects."""

    def __init__(self, data):
        self.data = data

    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        raise NotImplementedError

    def __ne__(self, other):
        raise NotImplementedError
|
||||
0
lib/sqlalchemy/testing/plugin/__init__.py
Normal file
0
lib/sqlalchemy/testing/plugin/__init__.py
Normal file
54
lib/sqlalchemy/testing/plugin/bootstrap.py
Normal file
54
lib/sqlalchemy/testing/plugin/bootstrap.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
Bootstrapper for test framework plugins.
|
||||
|
||||
The entire rationale for this system is to get the modules in plugin/
|
||||
imported without importing all of the supporting library, so that we can
|
||||
set up things for testing before coverage starts.
|
||||
|
||||
The rationale for all of plugin/ being *in* the supporting library in the
|
||||
first place is so that the testing and plugin suite is available to other
|
||||
libraries, mainly external SQLAlchemy and Alembic dialects, to make use
|
||||
of the same test environment and standard suites available to
|
||||
SQLAlchemy/Alembic themselves without the need to ship/install a separate
|
||||
package outside of SQLAlchemy.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
bootstrap_file = locals()["bootstrap_file"]
|
||||
to_bootstrap = locals()["to_bootstrap"]
|
||||
|
||||
|
||||
def load_file_as_module(name):
    """Import ``<name>.py`` from the directory containing
    ``bootstrap_file`` and return the module object, without going
    through the normal package import machinery.

    ``bootstrap_file`` is injected into this module's namespace by the
    caller that exec's this bootstrap script.
    """
    path = os.path.join(os.path.dirname(bootstrap_file), "%s.py" % name)

    if sys.version_info >= (3, 5):
        # modern importlib-based loading
        import importlib.util

        spec = importlib.util.spec_from_file_location(name, path)
        assert spec is not None
        assert spec.loader is not None
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
    else:
        # legacy loader for older interpreters
        import imp

        mod = imp.load_source(name, path)

    return mod
|
||||
|
||||
|
||||
# register the plugin modules under stable aliases so the test
# framework can import them before the full library is importable
if to_bootstrap == "pytest":
    sys.modules["sqla_plugin_base"] = load_file_as_module("plugin_base")
    sys.modules["sqla_plugin_base"].bootstrapped_as_sqlalchemy = True
    if sys.version_info < (3, 0):
        sys.modules["sqla_reinvent_fixtures"] = load_file_as_module(
            "reinvent_fixtures_py2k"
        )
    sys.modules["sqla_pytestplugin"] = load_file_as_module("pytestplugin")
else:
    raise Exception("unknown bootstrap: %s" % to_bootstrap)  # noqa
|
||||
789
lib/sqlalchemy/testing/plugin/plugin_base.py
Normal file
789
lib/sqlalchemy/testing/plugin/plugin_base.py
Normal file
@@ -0,0 +1,789 @@
|
||||
# plugin/plugin_base.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Testing extensions.
|
||||
|
||||
this module is designed to work as a testing-framework-agnostic library,
|
||||
created so that multiple test frameworks can be supported at once
|
||||
(mostly so that we can migrate to new ones). The current target
|
||||
is pytest.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
|
||||
# flag which indicates we are in the SQLAlchemy testing suite,
|
||||
# and not that of Alembic or a third party dialect.
|
||||
bootstrapped_as_sqlalchemy = False
|
||||
|
||||
log = logging.getLogger("sqlalchemy.testing.plugin_base")
|
||||
|
||||
|
||||
py3k = sys.version_info >= (3, 0)
|
||||
|
||||
if py3k:
|
||||
import configparser
|
||||
|
||||
ABC = abc.ABC
|
||||
else:
|
||||
import ConfigParser as configparser
|
||||
import collections as collections_abc # noqa
|
||||
|
||||
class ABC(object):
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
|
||||
# late imports
|
||||
fixtures = None
|
||||
engines = None
|
||||
exclusions = None
|
||||
warnings = None
|
||||
profiling = None
|
||||
provision = None
|
||||
assertions = None
|
||||
requirements = None
|
||||
config = None
|
||||
testing = None
|
||||
util = None
|
||||
file_config = None
|
||||
|
||||
logging = None
|
||||
include_tags = set()
|
||||
exclude_tags = set()
|
||||
options = None
|
||||
|
||||
|
||||
def setup_options(make_option):
|
||||
make_option(
|
||||
"--log-info",
|
||||
action="callback",
|
||||
type=str,
|
||||
callback=_log,
|
||||
help="turn on info logging for <LOG> (multiple OK)",
|
||||
)
|
||||
make_option(
|
||||
"--log-debug",
|
||||
action="callback",
|
||||
type=str,
|
||||
callback=_log,
|
||||
help="turn on debug logging for <LOG> (multiple OK)",
|
||||
)
|
||||
make_option(
|
||||
"--db",
|
||||
action="append",
|
||||
type=str,
|
||||
dest="db",
|
||||
help="Use prefab database uri. Multiple OK, "
|
||||
"first one is run by default.",
|
||||
)
|
||||
make_option(
|
||||
"--dbs",
|
||||
action="callback",
|
||||
zeroarg_callback=_list_dbs,
|
||||
help="List available prefab dbs",
|
||||
)
|
||||
make_option(
|
||||
"--dburi",
|
||||
action="append",
|
||||
type=str,
|
||||
dest="dburi",
|
||||
help="Database uri. Multiple OK, " "first one is run by default.",
|
||||
)
|
||||
make_option(
|
||||
"--dbdriver",
|
||||
action="append",
|
||||
type=str,
|
||||
dest="dbdriver",
|
||||
help="Additional database drivers to include in tests. "
|
||||
"These are linked to the existing database URLs by the "
|
||||
"provisioning system.",
|
||||
)
|
||||
make_option(
|
||||
"--dropfirst",
|
||||
action="store_true",
|
||||
dest="dropfirst",
|
||||
help="Drop all tables in the target database first",
|
||||
)
|
||||
make_option(
|
||||
"--disable-asyncio",
|
||||
action="store_true",
|
||||
help="disable test / fixtures / provisoning running in asyncio",
|
||||
)
|
||||
make_option(
|
||||
"--backend-only",
|
||||
action="store_true",
|
||||
dest="backend_only",
|
||||
help="Run only tests marked with __backend__ or __sparse_backend__",
|
||||
)
|
||||
make_option(
|
||||
"--nomemory",
|
||||
action="store_true",
|
||||
dest="nomemory",
|
||||
help="Don't run memory profiling tests",
|
||||
)
|
||||
make_option(
|
||||
"--notimingintensive",
|
||||
action="store_true",
|
||||
dest="notimingintensive",
|
||||
help="Don't run timing intensive tests",
|
||||
)
|
||||
make_option(
|
||||
"--profile-sort",
|
||||
type=str,
|
||||
default="cumulative",
|
||||
dest="profilesort",
|
||||
help="Type of sort for profiling standard output",
|
||||
)
|
||||
make_option(
|
||||
"--profile-dump",
|
||||
type=str,
|
||||
dest="profiledump",
|
||||
help="Filename where a single profile run will be dumped",
|
||||
)
|
||||
make_option(
|
||||
"--postgresql-templatedb",
|
||||
type=str,
|
||||
help="name of template database to use for PostgreSQL "
|
||||
"CREATE DATABASE (defaults to current database)",
|
||||
)
|
||||
make_option(
|
||||
"--low-connections",
|
||||
action="store_true",
|
||||
dest="low_connections",
|
||||
help="Use a low number of distinct connections - "
|
||||
"i.e. for Oracle TNS",
|
||||
)
|
||||
make_option(
|
||||
"--write-idents",
|
||||
type=str,
|
||||
dest="write_idents",
|
||||
help="write out generated follower idents to <file>, "
|
||||
"when -n<num> is used",
|
||||
)
|
||||
make_option(
|
||||
"--reversetop",
|
||||
action="store_true",
|
||||
dest="reversetop",
|
||||
default=False,
|
||||
help="Use a random-ordering set implementation in the ORM "
|
||||
"(helps reveal dependency issues)",
|
||||
)
|
||||
make_option(
|
||||
"--requirements",
|
||||
action="callback",
|
||||
type=str,
|
||||
callback=_requirements_opt,
|
||||
help="requirements class for testing, overrides setup.cfg",
|
||||
)
|
||||
make_option(
|
||||
"--with-cdecimal",
|
||||
action="store_true",
|
||||
dest="cdecimal",
|
||||
default=False,
|
||||
help="Monkeypatch the cdecimal library into Python 'decimal' "
|
||||
"for all tests",
|
||||
)
|
||||
make_option(
|
||||
"--include-tag",
|
||||
action="callback",
|
||||
callback=_include_tag,
|
||||
type=str,
|
||||
help="Include tests with tag <tag>",
|
||||
)
|
||||
make_option(
|
||||
"--exclude-tag",
|
||||
action="callback",
|
||||
callback=_exclude_tag,
|
||||
type=str,
|
||||
help="Exclude tests with tag <tag>",
|
||||
)
|
||||
make_option(
|
||||
"--write-profiles",
|
||||
action="store_true",
|
||||
dest="write_profiles",
|
||||
default=False,
|
||||
help="Write/update failing profiling data.",
|
||||
)
|
||||
make_option(
|
||||
"--force-write-profiles",
|
||||
action="store_true",
|
||||
dest="force_write_profiles",
|
||||
default=False,
|
||||
help="Unconditionally write/update profiling data.",
|
||||
)
|
||||
make_option(
|
||||
"--dump-pyannotate",
|
||||
type=str,
|
||||
dest="dump_pyannotate",
|
||||
help="Run pyannotate and dump json info to given file",
|
||||
)
|
||||
make_option(
|
||||
"--mypy-extra-test-path",
|
||||
type=str,
|
||||
action="append",
|
||||
default=[],
|
||||
dest="mypy_extra_test_paths",
|
||||
help="Additional test directories to add to the mypy tests. "
|
||||
"This is used only when running mypy tests. Multiple OK",
|
||||
)
|
||||
|
||||
|
||||
def configure_follower(follower_ident):
    """Configure required state for a follower.

    This invokes in the parent process and typically includes
    database creation.

    """
    from sqlalchemy.testing import provision

    # record the follower identity for the provisioning system
    provision.FOLLOWER_IDENT = follower_ident
|
||||
|
||||
|
||||
def memoize_important_follower_config(dict_):
    """Store important configuration we will need to send to a follower.

    This invokes in the parent process after normal config is set up.

    This is necessary as pytest seems to not be using forking, so we
    start with nothing in memory, *but* it isn't running our argparse
    callables, so we have to just copy all of that over.

    """
    dict_["memoized_config"] = {
        "include_tags": include_tags,
        "exclude_tags": exclude_tags,
    }
|
||||
|
||||
|
||||
def restore_important_follower_config(dict_):
    """Restore important configuration needed by a follower.

    This invokes in the follower process.

    """
    global include_tags, exclude_tags
    # merge rather than replace, preserving any tags already set locally
    include_tags.update(dict_["memoized_config"]["include_tags"])
    exclude_tags.update(dict_["memoized_config"]["exclude_tags"])
|
||||
|
||||
|
||||
def read_config():
    """Populate the module-level ``file_config`` from setup.cfg / test.cfg."""
    global file_config
    file_config = configparser.ConfigParser()
    file_config.read(["setup.cfg", "test.cfg"])
|
||||
|
||||
|
||||
def pre_begin(opt):
    """things to set up early, before coverage might be setup."""
    global options
    options = opt
    # run all @pre-registered configuration hooks
    for fn in pre_configure:
        fn(options, file_config)
|
||||
|
||||
|
||||
def set_coverage_flag(value):
    """Record whether a coverage run is active on the global options."""
    options.has_coverage = value
|
||||
|
||||
|
||||
def post_begin():
    """things to set up later, once we know coverage is running."""
    # Lazy setup of other options (post coverage)
    for fn in post_configure:
        fn(options, file_config)

    # late imports, has to happen after config.
    global util, fixtures, engines, exclusions, assertions, provision
    global warnings, profiling, config, testing
    from sqlalchemy import testing  # noqa
    from sqlalchemy.testing import fixtures, engines, exclusions  # noqa
    from sqlalchemy.testing import assertions, warnings, profiling  # noqa
    from sqlalchemy.testing import config, provision  # noqa
    from sqlalchemy import util  # noqa

    warnings.setup_filters()
|
||||
|
||||
|
||||
def _log(opt_str, value, parser):
    """Option callback for --log-info / --log-debug.

    Lazily imports the stdlib ``logging`` module into the module-level
    ``logging`` global (initialized to None) on first use, then sets
    the requested logger's level based on the option name suffix.
    """
    global logging
    if not logging:
        import logging

        logging.basicConfig()

    if opt_str.endswith("-info"):
        logging.getLogger(value).setLevel(logging.INFO)
    elif opt_str.endswith("-debug"):
        logging.getLogger(value).setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def _list_dbs(*args):
    """Option callback for --dbs: print the [db] aliases and exit."""
    print("Available --db options (use --dburi to override)")
    for macro in sorted(file_config.options("db")):
        print("%20s\t%s" % (macro, file_config.get("db", macro)))
    sys.exit(0)
|
||||
|
||||
|
||||
def _requirements_opt(opt_str, value, parser):
    """Option callback for --requirements: install the named class."""
    _setup_requirements(value)
|
||||
|
||||
|
||||
def _exclude_tag(opt_str, value, parser):
    # tags use underscores internally; accept dashed form on the CLI
    exclude_tags.add(value.replace("-", "_"))
|
||||
|
||||
|
||||
def _include_tag(opt_str, value, parser):
    # tags use underscores internally; accept dashed form on the CLI
    include_tags.add(value.replace("-", "_"))
|
||||
|
||||
|
||||
pre_configure = []
|
||||
post_configure = []
|
||||
|
||||
|
||||
def pre(fn):
    """Decorator registering ``fn`` to run during pre_begin()."""
    pre_configure.append(fn)
    return fn
|
||||
|
||||
|
||||
def post(fn):
    """Decorator: register *fn* as a post-configure hook and return it unchanged."""
    post_configure.append(fn)
    return fn
|
||||
|
||||
|
||||
@pre
def _setup_options(opt, file_config):
    # Stash the parsed option namespace in the module-global ``options``
    # for use by all later hooks.
    global options
    options = opt
|
||||
|
||||
|
||||
@pre
def _set_nomemory(opt, file_config):
    """Exclude memory-intensive tests when ``--nomemory`` was given."""
    if not opt.nomemory:
        return
    exclude_tags.add("memory_intensive")
|
||||
|
||||
|
||||
@pre
def _set_notimingintensive(opt, file_config):
    """Exclude timing-sensitive tests when ``--notimingintensive`` was given."""
    if not opt.notimingintensive:
        return
    exclude_tags.add("timing_intensive")
|
||||
|
||||
|
||||
@pre
def _monkeypatch_cdecimal(options, file_config):
    """When ``--cdecimal`` is set, substitute cdecimal for the stdlib decimal."""
    if not options.cdecimal:
        return
    import cdecimal

    # every subsequent "import decimal" now resolves to cdecimal
    sys.modules["decimal"] = cdecimal
|
||||
|
||||
|
||||
@post
def _init_symbols(options, file_config):
    # Instantiate the FixtureFunctions implementation registered by the
    # test-runner plugin (see set_fixture_functions) and install it on config.
    from sqlalchemy.testing import config

    config._fixture_functions = _fixture_fn_class()
|
||||
|
||||
|
||||
@post
def _set_disable_asyncio(opt, file_config):
    """Turn off asyncio test support when disabled explicitly or on py2."""
    if not opt.disable_asyncio and py3k:
        return
    from sqlalchemy.testing import asyncio

    asyncio.ENABLE_ASYNCIO = False
|
||||
|
||||
|
||||
@post
def _engine_uri(options, file_config):
    """Resolve the database URLs to test against and build a Config for each.

    URLs come from ``--dburi``, from ``--db`` names looked up in the [db]
    section of the config file, or from the "default" entry when neither
    was given.
    """

    from sqlalchemy import testing
    from sqlalchemy.testing import config
    from sqlalchemy.testing import provision

    if options.dburi:
        db_urls = list(options.dburi)
    else:
        db_urls = []

    extra_drivers = options.dbdriver or []

    if options.db:
        # --db values may be comma/space separated lists of [db] macros
        for db_token in options.db:
            for db in re.split(r"[,\s]+", db_token):
                if db not in file_config.options("db"):
                    raise RuntimeError(
                        "Unknown URI specifier '%s'. "
                        "Specify --dbs for known uris." % db
                    )
                else:
                    db_urls.append(file_config.get("db", db))

    if not db_urls:
        db_urls.append(file_config.get("db", "default"))

    config._current = None

    # expand each URL into per-driver variants as requested by --dbdriver
    expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))

    for db_url in expanded_urls:
        log.info("Adding database URL: %s", db_url)

        # when running as an xdist follower, record ident/url pairs so the
        # master process can clean up follower databases afterwards
        if options.write_idents and provision.FOLLOWER_IDENT:
            with open(options.write_idents, "a") as file_:
                file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")

        cfg = provision.setup_config(
            db_url, options, file_config, provision.FOLLOWER_IDENT
        )
        # the first config created becomes the "current" default
        if not config._current:
            cfg.set_as_current(cfg, testing)
|
||||
|
||||
|
||||
@post
def _requirements(options, file_config):
    """Load the requirements class named in the [sqla_testing] section."""
    requirement_cls = file_config.get("sqla_testing", "requirement_cls")
    _setup_requirements(requirement_cls)
|
||||
|
||||
|
||||
def _setup_requirements(argument):
    """Import and instantiate the requirements class given as "module:Class".

    The resulting instance is installed as both ``config.requirements`` and
    ``testing.requires``; a no-op if requirements were already set up.
    """
    from sqlalchemy.testing import config
    from sqlalchemy import testing

    # already configured (e.g. via the --requirements option); keep it
    if config.requirements is not None:
        return

    modname, clsname = argument.split(":")

    # importlib.import_module() only introduced in 2.7, a little
    # late
    mod = __import__(modname)
    for component in modname.split(".")[1:]:
        mod = getattr(mod, component)
    req_cls = getattr(mod, clsname)

    config.requirements = testing.requires = req_cls()

    # record whether we were bootstrapped as part of the sqlalchemy package
    config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
|
||||
|
||||
|
||||
@post
def _prep_testing_database(options, file_config):
    """With ``--dropfirst``, drop all schema objects in every configured db."""
    from sqlalchemy.testing import config

    if not options.dropfirst:
        return

    from sqlalchemy.testing import provision

    for cfg in config.Config.all_configs():
        provision.drop_all_schema_objects(cfg, cfg.db)
|
||||
|
||||
|
||||
@post
def _reverse_topological(options, file_config):
    """With ``--reversetop``, randomize unit-of-work topological ordering."""
    if not options.reversetop:
        return
    from sqlalchemy.orm.util import randomize_unitofwork

    randomize_unitofwork()
|
||||
|
||||
|
||||
@post
def _post_setup_options(opt, file_config):
    # Mirror the option namespace and config-file parser onto the config
    # module so tests can reach them via sqlalchemy.testing.config.
    from sqlalchemy.testing import config

    config.options = options
    config.file_config = file_config
|
||||
|
||||
|
||||
@post
def _setup_profiling(options, file_config):
    """Create the global ProfileStatsFile used by profiling-based tests."""
    from sqlalchemy.testing import profiling

    stats_path = file_config.get("sqla_testing", "profile_file")
    profiling._profile_stats = profiling.ProfileStatsFile(
        stats_path,
        sort=options.profilesort,
        dump=options.profiledump,
    )
|
||||
|
||||
|
||||
def want_class(name, cls):
    """Return True if *cls* should be collected as a test class."""
    if not issubclass(cls, fixtures.TestBase):
        return False
    if name.startswith("_"):
        return False
    if config.options.backend_only:
        # in backend-only mode, collect only classes that declare some
        # backend affinity
        return bool(
            getattr(cls, "__backend__", False)
            or getattr(cls, "__sparse_backend__", False)
            or getattr(cls, "__only_on__", False)
        )
    return True
|
||||
|
||||
|
||||
def want_method(cls, fn):
    """Return True if method *fn* of *cls* should be collected as a test.

    Applies the module-global ``include_tags`` / ``exclude_tags`` sets
    against the class-level ``__tags__`` and any per-function exclusion
    rules attached by the exclusions decorators.
    """
    if not fn.__name__.startswith("test_"):
        return False
    elif fn.__module__ is None:
        # skip functions that report no module of origin
        return False
    elif include_tags:
        # with --include-tag, run only if the class tags or the function's
        # own exclusion rule match the include set
        return (
            hasattr(cls, "__tags__")
            and exclusions.tags(cls.__tags__).include_test(
                include_tags, exclude_tags
            )
        ) or (
            hasattr(fn, "_sa_exclusion_extend")
            and fn._sa_exclusion_extend.include_test(
                include_tags, exclude_tags
            )
        )
    elif exclude_tags and hasattr(cls, "__tags__"):
        return exclusions.tags(cls.__tags__).include_test(
            include_tags, exclude_tags
        )
    elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
        return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
    else:
        return True
|
||||
|
||||
|
||||
def generate_sub_tests(cls, module):
    """Yield per-backend subclasses of *cls*, or *cls* itself.

    Classes marked ``__backend__`` or ``__sparse_backend__`` are expanded
    into one subclass per eligible database config; each subclass is also
    installed on *module* so the test runner can collect it.
    """
    if getattr(cls, "__backend__", False) or getattr(
        cls, "__sparse_backend__", False
    ):
        sparse = getattr(cls, "__sparse_backend__", False)
        for cfg in _possible_configs_for_cls(cls, sparse=sparse):
            orig_name = cls.__name__

            # we can have special chars in these names except for the
            # pytest junit plugin, which is tripped up by the brackets
            # and periods, so sanitize

            alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
            alpha_name = re.sub(r"_+$", "", alpha_name)
            name = "%s_%s" % (cls.__name__, alpha_name)
            subcls = type(
                name,
                (cls,),
                {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
            )
            # make the generated class importable from the module
            setattr(module, name, subcls)
            yield subcls
    else:
        yield cls
|
||||
|
||||
|
||||
def start_test_class_outside_fixtures(cls):
    # Run before pytest class-scoped fixtures: apply skips (may raise
    # SkipTest for the whole class) and set up a class-scoped engine.
    _do_skips(cls)
    _setup_engine(cls)
|
||||
|
||||
|
||||
def stop_test_class(cls):
    """Tear down a test class inside the pytest fixture scope."""
    # close sessions, immediate connections, etc.
    fixtures.stop_test_class_inside_fixtures(cls)

    # close outstanding connection pool connections, dispose of
    # additional engines
    engines.testing_reaper.stop_test_class_inside_fixtures()
|
||||
|
||||
|
||||
def stop_test_class_outside_fixtures(cls):
    """Final per-class teardown, run after all pytest fixtures have exited."""
    engines.testing_reaper.stop_test_class_outside_fixtures()
    provision.stop_test_class_outside_fixtures(config, config.db, cls)
    try:
        # connection-count assertions are skipped under --low-connections
        if not options.low_connections:
            assertions.global_cleanup_assertions()
    finally:
        # always restore the default engine config, even if assertions fail
        _restore_engine()
|
||||
|
||||
|
||||
def _restore_engine():
    # undo any per-class engine pushed by _setup_engine / _setup_config
    if config._current:
        config._current.reset(testing)
|
||||
|
||||
|
||||
def final_process_cleanup():
    """End-of-process cleanup: dispose engines, assert no leaked resources."""
    engines.testing_reaper.final_cleanup()
    assertions.global_cleanup_assertions()
    _restore_engine()
|
||||
|
||||
|
||||
def _setup_engine(cls):
    """Push a class-scoped testing engine when the class sets __engine_options__."""
    engine_opts = getattr(cls, "__engine_options__", None)
    if not engine_opts:
        return
    opts = dict(engine_opts)
    opts["scope"] = "class"
    eng = engines.testing_engine(options=opts)
    config._current.push_engine(eng, testing)
|
||||
|
||||
|
||||
def before_test(test, test_module_name, test_class, test_name):
    """Begin profiling bookkeeping for an individual test."""

    # format looks like:
    # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
    cls_name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
    id_ = "%s.%s.%s" % (test_module_name, cls_name, test_name)

    profiling._start_current_test(id_)
|
||||
|
||||
|
||||
def after_test(test):
    # per-test teardown run inside the fixture scope
    fixtures.after_test()
    engines.testing_reaper.after_test()
|
||||
|
||||
|
||||
def after_test_fixtures(test):
    # per-test teardown run after function-level fixtures have exited
    engines.testing_reaper.after_test_outside_fixtures(test)
|
||||
|
||||
|
||||
def _possible_configs_for_cls(cls, reasons=None, sparse=False):
    """Return the database configs on which *cls* is able to run.

    Filters all known configs by the class's ``__unsupported_on__``,
    ``__only_on__``, ``__only_on_config__``, ``__requires__`` and
    ``__prefer_requires__`` attributes.  With ``sparse=True``, at most one
    config per dialect is kept.  Skip reasons are appended to *reasons*
    when provided.
    """
    all_configs = set(config.Config.all_configs())

    # hard exclusion: remove configs matching __unsupported_on__
    if cls.__unsupported_on__:
        spec = exclusions.db_spec(*cls.__unsupported_on__)
        for config_obj in list(all_configs):
            if spec(config_obj):
                all_configs.remove(config_obj)

    # hard inclusion: keep only configs matching __only_on__
    if getattr(cls, "__only_on__", None):
        spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
        for config_obj in list(all_configs):
            if not spec(config_obj):
                all_configs.remove(config_obj)

    # pin to the single config assigned by generate_sub_tests
    if getattr(cls, "__only_on_config__", None):
        all_configs.intersection_update([cls.__only_on_config__])

    if hasattr(cls, "__requires__"):
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__requires__:
                check = getattr(requirements, requirement)

                skip_reasons = check.matching_config_reasons(config_obj)
                if skip_reasons:
                    all_configs.remove(config_obj)
                    if reasons is not None:
                        reasons.extend(skip_reasons)
                    break

    if hasattr(cls, "__prefer_requires__"):
        # soft requirements: drop non-matching configs only if at least
        # one config would remain
        non_preferred = set()
        requirements = config.requirements
        for config_obj in list(all_configs):
            for requirement in cls.__prefer_requires__:
                check = getattr(requirements, requirement)

                if not check.enabled_for_config(config_obj):
                    non_preferred.add(config_obj)
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    if sparse:
        # pick only one config from each base dialect
        # sorted so we get the same backend each time selecting the highest
        # server version info.
        per_dialect = {}
        for cfg in reversed(
            sorted(
                all_configs,
                key=lambda cfg: (
                    cfg.db.name,
                    cfg.db.driver,
                    cfg.db.dialect.server_version_info,
                ),
            )
        ):
            db = cfg.db.name
            if db not in per_dialect:
                per_dialect[db] = cfg
        return per_dialect.values()

    return all_configs
|
||||
|
||||
|
||||
def _do_skips(cls):
    """Skip *cls* entirely when no configured database can run it.

    Also honors ``__skip_if__`` callables and ``__prefer_backends__``
    soft preferences; ensures the current config is one the class can use.
    """
    reasons = []
    all_configs = _possible_configs_for_cls(cls, reasons)

    if getattr(cls, "__skip_if__", False):
        for c in getattr(cls, "__skip_if__"):
            if c():
                config.skip_test(
                    "'%s' skipped by %s" % (cls.__name__, c.__name__)
                )

    if not all_configs:
        # no config survived the filters: skip with a detailed message
        # enumerating every known backend and the collected reasons
        msg = "'%s' unsupported on any DB implementation %s%s" % (
            cls.__name__,
            ", ".join(
                "'%s(%s)+%s'"
                % (
                    config_obj.db.name,
                    ".".join(
                        str(dig)
                        for dig in exclusions._server_version(config_obj.db)
                    ),
                    config_obj.db.driver,
                )
                for config_obj in config.Config.all_configs()
            ),
            ", ".join(reasons),
        )
        config.skip_test(msg)
    elif hasattr(cls, "__prefer_backends__"):
        # soft preference: narrow to preferred backends when possible
        non_preferred = set()
        spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
        for config_obj in all_configs:
            if not spec(config_obj):
                non_preferred.add(config_obj)
        if all_configs.difference(non_preferred):
            all_configs.difference_update(non_preferred)

    # switch the active config if the current one isn't usable by this class
    if config._current not in all_configs:
        _setup_config(all_configs.pop(), cls)
|
||||
|
||||
|
||||
def _setup_config(config_obj, ctx):
    # make config_obj the active database config (popped by _restore_engine)
    config._current.push(config_obj, testing)
|
||||
|
||||
|
||||
class FixtureFunctions(ABC):
    """Abstract interface the test-runner plugin implements to supply
    fixture services (skipping, combinations, fixtures) to the test suite.
    """

    @abc.abstractmethod
    def skip_test_exception(self, *arg, **kw):
        # return the runner-specific exception used to skip a test
        raise NotImplementedError()

    @abc.abstractmethod
    def combinations(self, *args, **kw):
        # runner-specific implementation of testing.combinations
        raise NotImplementedError()

    @abc.abstractmethod
    def param_ident(self, *args, **kw):
        # runner-specific parameter-set identifier
        raise NotImplementedError()

    @abc.abstractmethod
    def fixture(self, *arg, **kw):
        # runner-specific fixture decorator
        raise NotImplementedError()

    # NOTE(review): not marked @abc.abstractmethod like its siblings —
    # presumably optional for implementers; confirm before relying on it.
    def get_current_test_name(self):
        raise NotImplementedError()

    @abc.abstractmethod
    def mark_base_test_class(self):
        # runner-specific marker applied to the base test class
        raise NotImplementedError()
|
||||
|
||||
|
||||
_fixture_fn_class = None
|
||||
|
||||
|
||||
def set_fixture_functions(fixture_fn_class):
    # record the FixtureFunctions subclass to instantiate in _init_symbols
    global _fixture_fn_class
    _fixture_fn_class = fixture_fn_class
|
||||
820
lib/sqlalchemy/testing/plugin/pytestplugin.py
Normal file
820
lib/sqlalchemy/testing/plugin/pytestplugin.py
Normal file
@@ -0,0 +1,820 @@
|
||||
try:
|
||||
# installed by bootstrap.py
|
||||
import sqla_plugin_base as plugin_base
|
||||
except ImportError:
|
||||
# assume we're a package, use traditional import
|
||||
from . import plugin_base
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
from functools import update_wrapper
|
||||
import inspect
|
||||
import itertools
|
||||
import operator
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# True when running under Python 2; enables the homegrown fixture shims
# used below in place of native pytest fixtures.
py2k = sys.version_info < (3, 0)
if py2k:
    try:
        # installed by bootstrap.py
        import sqla_reinvent_fixtures as reinvent_fixtures_py2k
    except ImportError:
        from . import reinvent_fixtures_py2k
|
||||
|
||||
|
||||
def pytest_addoption(parser):
    """pytest hook: register SQLAlchemy's command-line options."""
    group = parser.getgroup("sqlalchemy")

    def make_option(name, **kw):
        # adapt plugin_base's optparse-style "callback" options to
        # argparse Actions, which is what pytest uses
        callback_ = kw.pop("callback", None)
        if callback_:

            class CallableAction(argparse.Action):
                def __call__(
                    self, parser, namespace, values, option_string=None
                ):
                    callback_(option_string, values, parser)

            kw["action"] = CallableAction

        zeroarg_callback = kw.pop("zeroarg_callback", None)
        if zeroarg_callback:

            class CallableAction(argparse.Action):
                # zero-argument flag: store True and invoke the callback
                def __init__(
                    self,
                    option_strings,
                    dest,
                    default=False,
                    required=False,
                    help=None,  # noqa
                ):
                    super(CallableAction, self).__init__(
                        option_strings=option_strings,
                        dest=dest,
                        nargs=0,
                        const=True,
                        default=default,
                        required=required,
                        help=help,
                    )

                def __call__(
                    self, parser, namespace, values, option_string=None
                ):
                    zeroarg_callback(option_string, values, parser)

            kw["action"] = CallableAction

        group.addoption(name, **kw)

    plugin_base.setup_options(make_option)
    plugin_base.read_config()
|
||||
|
||||
|
||||
def pytest_configure(config):
    """pytest hook: wire up xdist, follower config, and plugin_base startup."""
    if config.pluginmanager.hasplugin("xdist"):
        config.pluginmanager.register(XDistHooks())

    if hasattr(config, "workerinput"):
        # we are an xdist worker: restore the config memoized by the master
        plugin_base.restore_important_follower_config(config.workerinput)
        plugin_base.configure_follower(config.workerinput["follower_ident"])
    else:
        # master (or non-xdist) process: start the idents file fresh
        if config.option.write_idents and os.path.exists(
            config.option.write_idents
        ):
            os.remove(config.option.write_idents)

    plugin_base.pre_begin(config.option)

    plugin_base.set_coverage_flag(
        bool(getattr(config.option, "cov_source", False))
    )

    plugin_base.set_fixture_functions(PytestFixtureFunctions)

    if config.option.dump_pyannotate:
        global DUMP_PYANNOTATE
        DUMP_PYANNOTATE = True
|
||||
|
||||
|
||||
DUMP_PYANNOTATE = False
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def collect_types_fixture():
    # collect pyannotate runtime type info around each test when requested
    if DUMP_PYANNOTATE:
        from pyannotate_runtime import collect_types

        collect_types.start()
    yield
    if DUMP_PYANNOTATE:
        collect_types.stop()
|
||||
|
||||
|
||||
def pytest_sessionstart(session):
    from sqlalchemy.testing import asyncio

    # run post_begin inside a greenlet so async DB drivers can be set up
    asyncio._assume_async(plugin_base.post_begin)
|
||||
|
||||
|
||||
def pytest_sessionfinish(session):
    from sqlalchemy.testing import asyncio

    # end-of-session cleanup, greenlet-wrapped for async backends
    asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)

    if session.config.option.dump_pyannotate:
        from pyannotate_runtime import collect_types

        collect_types.dump_stats(session.config.option.dump_pyannotate)
|
||||
|
||||
|
||||
def pytest_collection_finish(session):
    # when --dump-pyannotate is set, restrict type collection to files under
    # lib/sqlalchemy, excluding the testing framework itself
    if session.config.option.dump_pyannotate:
        from pyannotate_runtime import collect_types

        lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")

        def _filter(filename):
            filename = os.path.normpath(os.path.abspath(filename))
            if "lib/sqlalchemy" not in os.path.commonpath(
                [filename, lib_sqlalchemy]
            ):
                return None
            if "testing" in filename:
                return None

            return filename

        collect_types.init_types_collection(filter_filename=_filter)
|
||||
|
||||
|
||||
class XDistHooks(object):
    """pytest-xdist hooks: create/drop a follower database per worker node."""

    def pytest_configure_node(self, node):
        from sqlalchemy.testing import provision
        from sqlalchemy.testing import asyncio

        # the master for each node fills workerinput dictionary
        # which pytest-xdist will transfer to the subprocess

        plugin_base.memoize_important_follower_config(node.workerinput)

        # unique ident used to name this worker's follower database
        node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]

        asyncio._maybe_async_provisioning(
            provision.create_follower_db, node.workerinput["follower_ident"]
        )

    def pytest_testnodedown(self, node, error):
        from sqlalchemy.testing import provision
        from sqlalchemy.testing import asyncio

        asyncio._maybe_async_provisioning(
            provision.drop_follower_db, node.workerinput["follower_ident"]
        )
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(session, config, items):
    """pytest hook: expand __backend__ classes into per-database test items."""

    # look for all those classes that specify __backend__ and
    # expand them out into per-database test cases.

    # this is much easier to do within pytest_pycollect_makeitem, however
    # pytest is iterating through cls.__dict__ as makeitem is
    # called which causes a "dictionary changed size" error on py3k.
    # I'd submit a pullreq for them to turn it into a list first, but
    # it's to suit the rather odd use case here which is that we are adding
    # new classes to a module on the fly.

    from sqlalchemy.testing import asyncio

    # original class -> test name -> replacement items
    rebuilt_items = collections.defaultdict(
        lambda: collections.defaultdict(list)
    )

    # drop items with no class or an underscore-prefixed (private) class
    items[:] = [
        item
        for item in items
        if item.getparent(pytest.Class) is not None
        and not item.getparent(pytest.Class).name.startswith("_")
    ]

    test_classes = set(item.getparent(pytest.Class) for item in items)

    def collect(element):
        # recursively flatten a pytest collector into its test functions
        for inst_or_fn in element.collect():
            if isinstance(inst_or_fn, pytest.Collector):
                # no yield from in 2.7
                for el in collect(inst_or_fn):
                    yield el
            else:
                yield inst_or_fn

    def setup_test_classes():
        for test_class in test_classes:
            for sub_cls in plugin_base.generate_sub_tests(
                test_class.cls, test_class.module
            ):
                if sub_cls is not test_class.cls:
                    per_cls_dict = rebuilt_items[test_class.cls]

                    # support pytest 5.4.0 and above pytest.Class.from_parent
                    ctor = getattr(pytest.Class, "from_parent", pytest.Class)
                    module = test_class.getparent(pytest.Module)
                    for fn in collect(
                        ctor(name=sub_cls.__name__, parent=module)
                    ):
                        per_cls_dict[fn.name].append(fn)

    # class requirements will sometimes need to access the DB to check
    # capabilities, so need to do this for async
    asyncio._maybe_async_provisioning(setup_test_classes)

    # swap expanded per-backend items in place of the originals
    newitems = []
    for item in items:
        cls_ = item.cls
        if cls_ in rebuilt_items:
            newitems.extend(rebuilt_items[cls_][item.name])
        else:
            newitems.append(item)

    if py2k:
        for item in newitems:
            reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item)

    # seems like the functions attached to a test class aren't sorted already?
    # is that true and why's that? (when using unittest, they're sorted)
    items[:] = sorted(
        newitems,
        key=lambda item: (
            item.getparent(pytest.Module).name,
            item.getparent(pytest.Class).name,
            item.name,
        ),
    )
|
||||
|
||||
|
||||
def pytest_pycollect_makeitem(collector, name, obj):
    """pytest hook: decide how *obj* is collected (class, method, or skipped)."""
    if inspect.isclass(obj) and plugin_base.want_class(name, obj):
        from sqlalchemy.testing import config

        # wrap test_ methods for greenlet execution when any backend is async
        if config.any_async:
            obj = _apply_maybe_async(obj)

        # support pytest 5.4.0 and above pytest.Class.from_parent
        ctor = getattr(pytest.Class, "from_parent", pytest.Class)
        return [
            ctor(name=parametrize_cls.__name__, parent=collector)
            for parametrize_cls in _parametrize_cls(collector.module, obj)
        ]
    elif (
        inspect.isfunction(obj)
        and collector.cls is not None
        and plugin_base.want_method(collector.cls, obj)
    ):
        # None means, fall back to default logic, which includes
        # method-level parametrize
        return None
    else:
        # empty list means skip this item
        return []
|
||||
|
||||
|
||||
def _is_wrapped_coroutine_function(fn):
|
||||
while hasattr(fn, "__wrapped__"):
|
||||
fn = fn.__wrapped__
|
||||
|
||||
return inspect.iscoroutinefunction(fn)
|
||||
|
||||
|
||||
def _apply_maybe_async(obj, recurse=True):
    """Wrap the test_ methods of class *obj* (and its bases, when *recurse*)
    so they run through asyncio._maybe_async, enabling greenlet execution.
    """
    from sqlalchemy.testing import asyncio

    for name, value in vars(obj).items():
        if (
            (callable(value) or isinstance(value, classmethod))
            and not getattr(value, "_maybe_async_applied", False)
            and (name.startswith("test_"))
            and not _is_wrapped_coroutine_function(value)
        ):
            # unwrap classmethods so the underlying function is decorated
            is_classmethod = False
            if isinstance(value, classmethod):
                value = value.__func__
                is_classmethod = True

            @_pytest_fn_decorator
            def make_async(fn, *args, **kwargs):
                return asyncio._maybe_async(fn, *args, **kwargs)

            do_async = make_async(value)
            if is_classmethod:
                do_async = classmethod(do_async)
            # marker prevents double-wrapping on repeated application
            do_async._maybe_async_applied = True

            setattr(obj, name, do_async)
    if recurse:
        # also wrap inherited test_ methods defined on base classes
        for cls in obj.mro()[1:]:
            if cls != object:
                _apply_maybe_async(cls, False)
    return obj
|
||||
|
||||
|
||||
def _parametrize_cls(module, cls):
    """implement a class-based version of pytest parametrize."""

    if "_sa_parametrize" not in cls.__dict__:
        return [cls]

    _sa_parametrize = cls._sa_parametrize
    classes = []
    # cartesian product across all registered parameter lists
    for full_param_set in itertools.product(
        *[params for argname, params in _sa_parametrize]
    ):
        # class attributes to set on the generated subclass
        cls_variables = {}

        for argname, param in zip(
            [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
        ):
            if not argname:
                raise TypeError("need argnames for class-based combinations")
            argname_split = re.split(r",\s*", argname)
            for arg, val in zip(argname_split, param.values):
                cls_variables[arg] = val
        parametrized_name = "_".join(
            # token is a string, but in py2k pytest is giving us a unicode,
            # so call str() on it.
            str(re.sub(r"\W", "", token))
            for param in full_param_set
            for token in param.id.split("-")
        )
        name = "%s_%s" % (cls.__name__, parametrized_name)
        newcls = type.__new__(type, name, (cls,), cls_variables)
        # install on the module so pytest can collect the generated class
        setattr(module, name, newcls)
        classes.append(newcls)
    return classes
|
||||
|
||||
|
||||
_current_class = None
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
    from sqlalchemy.testing import asyncio

    # pytest_runtest_setup runs *before* pytest fixtures with scope="class".
    # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
    # for the whole class and has to run things that are across all current
    # databases, so we run this outside of the pytest fixture system altogether
    # and ensure asyncio greenlet if any engines are async

    global _current_class

    # first test of a new class: run class-level setup once
    if isinstance(item, pytest.Function) and _current_class is None:
        asyncio._maybe_async_provisioning(
            plugin_base.start_test_class_outside_fixtures,
            item.cls,
        )
        _current_class = item.getparent(pytest.Class)
|
||||
|
||||
|
||||
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item, nextitem):
    """Hook wrapper: per-test teardown, plus class finalization after the
    last test of a class."""
    # runs inside of pytest function fixture scope
    # after test function runs
    from sqlalchemy.testing import asyncio
    from sqlalchemy.util import string_types

    asyncio._maybe_async(plugin_base.after_test, item)

    yield
    # this is now after all the fixture teardown have run, the class can be
    # finalized. Since pytest v7 this finalizer can no longer be added in
    # pytest_runtest_setup since the class has not yet been setup at that
    # time.
    # See https://github.com/pytest-dev/pytest/issues/9343
    global _current_class, _current_report

    if _current_class is not None and (
        # last test or a new class
        nextitem is None
        or nextitem.getparent(pytest.Class) is not _current_class
    ):
        _current_class = None

        try:
            asyncio._maybe_async_provisioning(
                plugin_base.stop_test_class_outside_fixtures, item.cls
            )
        except Exception as e:
            # in case of an exception during teardown attach the original
            # error to the exception message, otherwise it will get lost
            if _current_report.failed:
                if not e.args:
                    e.args = (
                        "__Original test failure__:\n"
                        + _current_report.longreprtext,
                    )
                elif e.args[-1] and isinstance(e.args[-1], string_types):
                    # append to the last string argument in place
                    args = list(e.args)
                    args[-1] += (
                        "\n__Original test failure__:\n"
                        + _current_report.longreprtext
                    )
                    e.args = tuple(args)
                else:
                    e.args += (
                        "__Original test failure__",
                        _current_report.longreprtext,
                    )
            raise
        finally:
            _current_report = None
|
||||
|
||||
|
||||
def pytest_runtest_call(item):
    # runs inside of pytest function fixture scope
    # before test function runs

    from sqlalchemy.testing import asyncio

    asyncio._maybe_async(
        plugin_base.before_test,
        item,
        item.module.__name__,
        item.cls,
        item.name,
    )
|
||||
|
||||
|
||||
_current_report = None
|
||||
|
||||
|
||||
def pytest_runtest_logreport(report):
    # remember the report for the "call" phase so teardown can surface the
    # original failure if class-level teardown itself raises
    global _current_report
    if report.when == "call":
        _current_report = report
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
def setup_class_methods(request):
    # class-scoped fixture invoking the suite's setup_test_class /
    # teardown_test_class hooks (greenlet-wrapped for async backends)
    from sqlalchemy.testing import asyncio

    cls = request.cls

    if hasattr(cls, "setup_test_class"):
        asyncio._maybe_async(cls.setup_test_class)

    if py2k:
        reinvent_fixtures_py2k.run_class_fixture_setup(request)

    yield

    if py2k:
        reinvent_fixtures_py2k.run_class_fixture_teardown(request)

    if hasattr(cls, "teardown_test_class"):
        asyncio._maybe_async(cls.teardown_test_class)

    asyncio._maybe_async(plugin_base.stop_test_class, cls)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def setup_test_methods(request):
    """Function-scoped fixture running per-test setup/teardown hooks in a
    documented, fixed order relative to pytest's own fixture phases."""
    from sqlalchemy.testing import asyncio

    # called for each test

    self = request.instance

    # before this fixture runs:

    # 1. function level "autouse" fixtures under py3k (examples: TablesTest
    # define tables / data, MappedTest define tables / mappers / data)

    # 2. run homegrown function level "autouse" fixtures under py2k
    if py2k:
        reinvent_fixtures_py2k.run_fn_fixture_setup(request)

    # 3. run outer xdist-style setup
    if hasattr(self, "setup_test"):
        asyncio._maybe_async(self.setup_test)

    # alembic test suite is using setUp and tearDown
    # xdist methods; support these in the test suite
    # for the near term
    if hasattr(self, "setUp"):
        asyncio._maybe_async(self.setUp)

    # inside the yield:
    # 4. function level fixtures defined on test functions themselves,
    # e.g. "connection", "metadata" run next

    # 5. pytest hook pytest_runtest_call then runs

    # 6. test itself runs

    yield

    # yield finishes:

    # 7. function level fixtures defined on test functions
    # themselves, e.g. "connection" rolls back the transaction, "metadata"
    # emits drop all

    # 8. pytest hook pytest_runtest_teardown hook runs, this is associated
    # with fixtures close all sessions, provisioning.stop_test_class(),
    # engines.testing_reaper -> ensure all connection pool connections
    # are returned, engines created by testing_engine that aren't the
    # config engine are disposed

    asyncio._maybe_async(plugin_base.after_test_fixtures, self)

    # 10. run xdist-style teardown
    if hasattr(self, "tearDown"):
        asyncio._maybe_async(self.tearDown)

    if hasattr(self, "teardown_test"):
        asyncio._maybe_async(self.teardown_test)

    # 11. run homegrown function-level "autouse" fixtures under py2k
    if py2k:
        reinvent_fixtures_py2k.run_fn_fixture_teardown(request)

    # 12. function level "autouse" fixtures under py3k (examples: TablesTest /
    # MappedTest delete table data, possibly drop tables and clear mappers
    # depending on the flags defined by the test class)
|
||||
|
||||
|
||||
def getargspec(fn):
    """Return the argspec for *fn*, picking the py2/py3-appropriate API."""
    if sys.version_info.major == 3:
        return inspect.getfullargspec(fn)
    return inspect.getargspec(fn)
|
||||
|
||||
|
||||
def _pytest_fn_decorator(target):
    """Port of langhelpers.decorator with pytest-specific tricks."""

    from sqlalchemy.util.langhelpers import format_argspec_plus
    from sqlalchemy.util.compat import inspect_getfullargspec

    def _exec_code_in_env(code, env, fn_name):
        # compile generated wrapper source and pull the function back out
        exec(code, env)
        return env[fn_name]

    def decorate(fn, add_positional_parameters=()):

        spec = inspect_getfullargspec(fn)
        if add_positional_parameters:
            spec.args.extend(add_positional_parameters)

        # build a wrapper with a signature identical to fn's, so pytest's
        # introspection (fixtures, parametrize) continues to work
        metadata = dict(
            __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
        )
        metadata.update(format_argspec_plus(spec, grouped=False))
        code = (
            """\
def %(name)s(%(args)s):
    return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
"""
            % metadata
        )
        decorated = _exec_code_in_env(
            code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
        )
        if not add_positional_parameters:
            decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
            decorated.__wrapped__ = fn
            return update_wrapper(decorated, fn)
        else:
            # this is the pytest hacky part. don't do a full update wrapper
            # because pytest is really being sneaky about finding the args
            # for the wrapped function
            decorated.__module__ = fn.__module__
            decorated.__name__ = fn.__name__
            if hasattr(fn, "pytestmark"):
                decorated.pytestmark = fn.pytestmark
            return decorated

    return decorate
|
||||
|
||||
|
||||
class PytestFixtureFunctions(plugin_base.FixtureFunctions):
    """pytest-backed implementation of the ``plugin_base.FixtureFunctions``
    interface, adapting SQLAlchemy's test-suite abstractions (combinations,
    fixtures, async tests) to pytest APIs."""

    def skip_test_exception(self, *arg, **kw):
        """Return the exception pytest uses to signal a skipped test."""
        return pytest.skip.Exception(*arg, **kw)

    def mark_base_test_class(self):
        """Return the mark applied to the base test class so that the
        setup/teardown fixtures run for every test."""
        return pytest.mark.usefixtures(
            "setup_class_methods", "setup_test_methods"
        )

    # template characters for the ``id_`` argument of combinations():
    # "i" -> use value as-is, "r" -> repr(), "s" -> str(),
    # "n" -> __name__ (falling back to the type name)
    _combination_id_fns = {
        "i": lambda obj: obj,
        "r": repr,
        "s": str,
        "n": lambda obj: obj.__name__
        if hasattr(obj, "__name__")
        else type(obj).__name__,
    }

    def combinations(self, *arg_sets, **kw):
        """Facade for pytest.mark.parametrize.

        Automatically derives argument names from the callable which in our
        case is always a method on a class with positional arguments.

        ids for parameter sets are derived using an optional template.

        """
        from sqlalchemy.testing import exclusions

        # a single generator argument is expanded into the full arg sets
        if sys.version_info.major == 3:
            if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
                arg_sets = list(arg_sets[0])
        else:
            if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
                arg_sets = list(arg_sets[0])

        argnames = kw.pop("argnames", None)

        def _filter_exclusions(args):
            # split one arg tuple into (plain parameters, exclusion objects)
            result = []
            gathered_exclusions = []
            for a in args:
                if isinstance(a, exclusions.compound):
                    gathered_exclusions.append(a)
                else:
                    result.append(a)

            return result, gathered_exclusions

        id_ = kw.pop("id_", None)

        tobuild_pytest_params = []
        has_exclusions = False
        if id_:
            _combination_id_fns = self._combination_id_fns

            # because itemgetter is not consistent for one argument vs.
            # multiple, make it multiple in all cases and use a slice
            # to omit the first argument
            _arg_getter = operator.itemgetter(
                0,
                *[
                    idx
                    for idx, char in enumerate(id_)
                    if char in ("n", "r", "s", "a")
                ]
            )
            fns = [
                (operator.itemgetter(idx), _combination_id_fns[char])
                for idx, char in enumerate(id_)
                if char in _combination_id_fns
            ]

            for arg in arg_sets:
                if not isinstance(arg, tuple):
                    arg = (arg,)

                fn_params, param_exclusions = _filter_exclusions(arg)

                parameters = _arg_getter(fn_params)[1:]

                if param_exclusions:
                    has_exclusions = True

                tobuild_pytest_params.append(
                    (
                        parameters,
                        param_exclusions,
                        # render the test id from the template characters
                        "-".join(
                            comb_fn(getter(arg)) for getter, comb_fn in fns
                        ),
                    )
                )

        else:

            for arg in arg_sets:
                if not isinstance(arg, tuple):
                    arg = (arg,)

                fn_params, param_exclusions = _filter_exclusions(arg)

                if param_exclusions:
                    has_exclusions = True

                tobuild_pytest_params.append(
                    (fn_params, param_exclusions, None)
                )

        pytest_params = []
        for parameters, param_exclusions, id_ in tobuild_pytest_params:
            if has_exclusions:
                # when any param set has exclusions, every set gets an
                # (possibly empty) exclusions argument appended
                parameters += (param_exclusions,)

            param = pytest.param(*parameters, id=id_)
            pytest_params.append(param)

        def decorate(fn):
            if inspect.isclass(fn):
                if has_exclusions:
                    raise NotImplementedError(
                        "exclusions not supported for class level combinations"
                    )
                # class-level parametrization is deferred; collected later
                # from the _sa_parametrize attribute
                if "_sa_parametrize" not in fn.__dict__:
                    fn._sa_parametrize = []
                fn._sa_parametrize.append((argnames, pytest_params))
                return fn
            else:
                if argnames is None:
                    # derive names from the method signature, skipping self
                    _argnames = getargspec(fn).args[1:]
                else:
                    _argnames = re.split(r", *", argnames)

                if has_exclusions:
                    _argnames += ["_exclusions"]

                    @_pytest_fn_decorator
                    def check_exclusions(fn, *args, **kw):
                        # the trailing parameter carries this param set's
                        # exclusions; apply them before invoking the test
                        _exclusions = args[-1]
                        if _exclusions:
                            exlu = exclusions.compound().add(*_exclusions)
                            fn = exlu(fn)
                        return fn(*args[0:-1], **kw)

                    def process_metadata(spec):
                        spec.args.append("_exclusions")

                    fn = check_exclusions(
                        fn, add_positional_parameters=("_exclusions",)
                    )

                return pytest.mark.parametrize(_argnames, pytest_params)(fn)

        return decorate

    def param_ident(self, *parameters):
        """Build a pytest.param whose first element is its explicit id."""
        ident = parameters[0]
        return pytest.param(*parameters[1:], id=ident)

    def fixture(self, *arg, **kw):
        """Facade for pytest.fixture, adding async and py2k adaptations."""
        from sqlalchemy.testing import config
        from sqlalchemy.testing import asyncio

        # wrapping pytest.fixture function.  determine if
        # decorator was called as @fixture or @fixture().
        if len(arg) > 0 and callable(arg[0]):
            # was called as @fixture(), we have the function to wrap.
            fn = arg[0]
            arg = arg[1:]
        else:
            # was called as @fixture, don't have the function yet.
            fn = None

        # create a pytest.fixture marker.  because the fn is not being
        # passed, this is always a pytest.FixtureFunctionMarker()
        # object (or whatever pytest is calling it when you read this)
        # that is waiting for a function.
        fixture = pytest.fixture(*arg, **kw)

        # now apply wrappers to the function, including fixture itself

        def wrap(fn):
            if config.any_async:
                fn = asyncio._maybe_async_wrapper(fn)
            # other wrappers may be added here

            if py2k and "autouse" in kw:
                # py2k workaround for too-slow collection of autouse fixtures
                # in pytest 4.6.11.  See notes in reinvent_fixtures_py2k for
                # rationale.

                # comment this condition out in order to disable the
                # py2k workaround entirely.
                reinvent_fixtures_py2k.add_fixture(fn, fixture)
            else:
                # now apply FixtureFunctionMarker
                fn = fixture(fn)

            return fn

        if fn:
            return wrap(fn)
        else:
            return wrap

    def get_current_test_name(self):
        """Return the id of the currently running test, as set by pytest."""
        return os.environ.get("PYTEST_CURRENT_TEST")

    def async_test(self, fn):
        """Wrap a coroutine test function so pytest runs it synchronously."""
        from sqlalchemy.testing import asyncio

        @_pytest_fn_decorator
        def decorate(fn, *args, **kwargs):
            asyncio._run_coroutine_function(fn, *args, **kwargs)

        return decorate(fn)
|
||||
112
lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
Normal file
112
lib/sqlalchemy/testing/plugin/reinvent_fixtures_py2k.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
invent a quick version of pytest autouse fixtures as pytest's unacceptably slow
|
||||
collection/high memory use in pytest 4.6.11, which is the highest version that
|
||||
works in py2k.
|
||||
|
||||
by "too-slow" we mean the test suite can't even manage to be collected for a
|
||||
single process in less than 70 seconds or so and memory use seems to be very
|
||||
high as well. for two or four workers the job just times out after ten
|
||||
minutes.
|
||||
|
||||
so instead we have invented a very limited form of these fixtures, as our
|
||||
current use of "autouse" fixtures are limited to those in fixtures.py.
|
||||
|
||||
assumptions for these fixtures:
|
||||
|
||||
1. we are only using "function" or "class" scope
|
||||
|
||||
2. the functions must be associated with a test class
|
||||
|
||||
3. the fixture functions cannot themselves use pytest fixtures
|
||||
|
||||
4. the fixture functions must use yield, not return
|
||||
|
||||
When py2k support is removed and we can stay on a modern pytest version, this
|
||||
can all be removed.
|
||||
|
||||
|
||||
"""
|
||||
import collections
|
||||
|
||||
|
||||
_py2k_fixture_fn_names = collections.defaultdict(set)
|
||||
_py2k_class_fixtures = collections.defaultdict(
|
||||
lambda: collections.defaultdict(set)
|
||||
)
|
||||
_py2k_function_fixtures = collections.defaultdict(
|
||||
lambda: collections.defaultdict(set)
|
||||
)
|
||||
|
||||
_py2k_cls_fixture_stack = []
|
||||
_py2k_fn_fixture_stack = []
|
||||
|
||||
|
||||
def add_fixture(fn, fixture):
    """Register a fixture function for the py2k autouse emulation.

    Only "class" and "function" scopes are supported (see module docstring).
    The function is indexed by name; the actual class association happens
    later in scan_for_fixtures_to_use_for_class().
    """
    assert fixture.scope in ("class", "function")
    _py2k_fixture_fn_names[fn.__name__].add((fn, fixture.scope))
|
||||
|
||||
|
||||
def scan_for_fixtures_to_use_for_class(item):
    """Associate registered fixture functions with a collected test class.

    For each registered fixture name, if the test class has a method of that
    name whose underlying function is the registered one, record the bound
    method against the class in the MRO class where it was defined, keyed by
    scope.  (``im_func`` is the py2k unbound-method attribute; this module
    only runs under py2k.)
    """
    test_class = item.parent.parent.obj

    for name in _py2k_fixture_fn_names:
        for fixture_fn, scope in _py2k_fixture_fn_names[name]:
            meth = getattr(test_class, name, None)
            if meth and meth.im_func is fixture_fn:
                # find the first class in the MRO that actually defines
                # the method; fixtures run grouped by defining class
                for sup in test_class.__mro__:
                    if name in sup.__dict__:
                        if scope == "class":
                            _py2k_class_fixtures[test_class][sup].add(meth)
                        elif scope == "function":
                            _py2k_function_fixtures[test_class][sup].add(meth)
                        break
                break
|
||||
|
||||
|
||||
def run_class_fixture_setup(request):
    """Run the setup half of all class-scoped emulated fixtures.

    Each fixture generator is advanced to its first ``yield``; the paused
    generator is pushed onto a stack so that teardown can resume it later.
    A throwaway instance is created via ``__new__`` (bypassing ``__init__``)
    since class-scoped fixtures run before any test instance exists.
    """

    cls = request.cls
    self = cls.__new__(cls)

    # .get() rather than [] so the defaultdict doesn't grow a new entry
    fixtures_for_this_class = _py2k_class_fixtures.get(cls)

    if fixtures_for_this_class:
        for sup_ in cls.__mro__:
            for fn in fixtures_for_this_class.get(sup_, ()):
                iter_ = fn(self)
                next(iter_)

                _py2k_cls_fixture_stack.append(iter_)
|
||||
|
||||
|
||||
def run_class_fixture_teardown(request):
    """Run the teardown half of all class-scoped emulated fixtures.

    Pops each paused fixture generator off the stack (most recent first)
    and resumes it past its ``yield``; the expected StopIteration is
    swallowed.
    """
    while _py2k_cls_fixture_stack:
        gen = _py2k_cls_fixture_stack.pop()
        try:
            next(gen)
        except StopIteration:
            pass
|
||||
|
||||
|
||||
def run_fn_fixture_setup(request):
    """Run the setup half of all function-scoped emulated fixtures.

    Mirrors run_class_fixture_setup() but uses the real test instance and
    walks the MRO in reverse so base-class fixtures run first.
    """
    cls = request.cls
    self = request.instance

    # .get() rather than [] so the defaultdict doesn't grow a new entry
    fixtures_for_this_class = _py2k_function_fixtures.get(cls)

    if fixtures_for_this_class:
        for sup_ in reversed(cls.__mro__):
            for fn in fixtures_for_this_class.get(sup_, ()):
                iter_ = fn(self)
                next(iter_)

                _py2k_fn_fixture_stack.append(iter_)
|
||||
|
||||
|
||||
def run_fn_fixture_teardown(request):
    """Run the teardown half of all function-scoped emulated fixtures.

    Drains the function-level stack in LIFO order, resuming each paused
    fixture generator past its ``yield`` and swallowing the expected
    StopIteration.
    """
    while _py2k_fn_fixture_stack:
        gen = _py2k_fn_fixture_stack.pop()
        try:
            next(gen)
        except StopIteration:
            pass
|
||||
335
lib/sqlalchemy/testing/profiling.py
Normal file
335
lib/sqlalchemy/testing/profiling.py
Normal file
@@ -0,0 +1,335 @@
|
||||
# testing/profiling.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
"""Profiling support for unit and performance tests.
|
||||
|
||||
These are special purpose profiling methods which operate
|
||||
in a more fine-grained way than nose's profiling plugin.
|
||||
|
||||
"""
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import os
|
||||
import platform
|
||||
import pstats
|
||||
import re
|
||||
import sys
|
||||
|
||||
from . import config
|
||||
from .util import gc_collect
|
||||
from ..util import has_compiled_ext
|
||||
|
||||
|
||||
try:
|
||||
import cProfile
|
||||
except ImportError:
|
||||
cProfile = None
|
||||
|
||||
_profile_stats = None
|
||||
"""global ProfileStatsFileInstance.
|
||||
|
||||
plugin_base assigns this at the start of all tests.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
_current_test = None
|
||||
"""String id of current test.
|
||||
|
||||
plugin_base assigns this at the start of each test using
|
||||
_start_current_test.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def _start_current_test(id_):
    """Record *id_* as the currently-running test.

    Called by plugin_base at the start of each test.  When profiles are
    being force-written, any previously recorded counts for this test are
    cleared so they will be regenerated.
    """
    global _current_test
    _current_test = id_

    if _profile_stats.force_write:
        _profile_stats.reset_count()
|
||||
|
||||
|
||||
class ProfileStatsFile(object):
    """Store per-platform/fn profiling results in a file.

    There was no json module available when this was written, but now
    the file format which is very deterministically line oriented is kind of
    handy in any case for diffs and merges.

    """

    def __init__(self, filename, sort="cumulative", dump=None):
        # force_write: always overwrite recorded counts;
        # write: record counts for tests that have none yet
        self.force_write = (
            config.options is not None and config.options.force_write_profiles
        )
        self.write = self.force_write or (
            config.options is not None and config.options.write_profiles
        )
        self.fname = os.path.abspath(filename)
        self.short_fname = os.path.split(self.fname)[-1]
        # data[test_key][platform_key] -> {"counts": [...], "lineno": n,
        #                                  "current_count": i}
        self.data = collections.defaultdict(
            lambda: collections.defaultdict(dict)
        )
        self.dump = dump
        self.sort = sort
        self._read()
        if self.write:
            # rewrite for the case where features changed,
            # etc.
            self._write()

    @property
    def platform_key(self):
        """A string key identifying this platform/driver combination,
        used to select the applicable callcounts from the stats file."""

        dbapi_key = config.db.name + "_" + config.db.driver

        if config.db.name == "sqlite" and config.db.dialect._is_url_file_db(
            config.db.url
        ):
            dbapi_key += "_file"

        # keep it at 2.7, 3.1, 3.2, etc. for now.
        py_version = ".".join([str(v) for v in sys.version_info[0:2]])

        platform_tokens = [
            platform.machine(),
            platform.system().lower(),
            platform.python_implementation().lower(),
            py_version,
            dbapi_key,
        ]

        platform_tokens.append(
            "nativeunicode"
            if config.db.dialect.convert_unicode
            else "dbapiunicode"
        )
        _has_cext = has_compiled_ext()
        platform_tokens.append(_has_cext and "cextensions" or "nocextensions")
        return "_".join(platform_tokens)

    def has_stats(self):
        """Return True if counts are recorded for the current test on
        this platform."""
        test_key = _current_test
        return (
            test_key in self.data and self.platform_key in self.data[test_key]
        )

    def result(self, callcount):
        """Record or compare *callcount* for the current test invocation.

        Returns ``(lineno, expected_count)`` when a recorded count exists
        for this invocation index, or None when the count was newly
        appended (and written out if recording is enabled).
        """
        test_key = _current_test
        per_fn = self.data[test_key]
        per_platform = per_fn[self.platform_key]

        if "counts" not in per_platform:
            per_platform["counts"] = counts = []
        else:
            counts = per_platform["counts"]

        if "current_count" not in per_platform:
            per_platform["current_count"] = current_count = 0
        else:
            current_count = per_platform["current_count"]

        has_count = len(counts) > current_count

        if not has_count:
            counts.append(callcount)
            if self.write:
                self._write()
            result = None
        else:
            result = per_platform["lineno"], counts[current_count]
        per_platform["current_count"] += 1
        return result

    def reset_count(self):
        """Discard recorded counts for the current test on this platform."""
        test_key = _current_test
        # since self.data is a defaultdict, don't access a key
        # if we don't know it's there first.
        if test_key not in self.data:
            return
        per_fn = self.data[test_key]
        if self.platform_key not in per_fn:
            return
        per_platform = per_fn[self.platform_key]
        if "counts" in per_platform:
            per_platform["counts"][:] = []

    def replace(self, callcount):
        """Overwrite the most recently consumed count with *callcount*."""
        test_key = _current_test
        per_fn = self.data[test_key]
        per_platform = per_fn[self.platform_key]
        counts = per_platform["counts"]
        current_count = per_platform["current_count"]
        # current_count was already advanced by result(); back up one slot
        if current_count < len(counts):
            counts[current_count - 1] = callcount
        else:
            counts[-1] = callcount
        if self.write:
            self._write()

    def _header(self):
        # explanatory comment block written at the top of the stats file
        return (
            "# %s\n"
            "# This file is written out on a per-environment basis.\n"
            "# For each test in aaa_profiling, the corresponding "
            "function and \n"
            "# environment is located within this file.  "
            "If it doesn't exist,\n"
            "# the test is skipped.\n"
            "# If a callcount does exist, it is compared "
            "to what we received. \n"
            "# assertions are raised if the counts do not match.\n"
            "# \n"
            "# To add a new callcount test, apply the function_call_count \n"
            "# decorator and re-run the tests using the --write-profiles \n"
            "# option - this file will be rewritten including the new count.\n"
            "# \n"
        ) % (self.fname)

    def _read(self):
        # load existing stats; silently start empty if the file is absent
        try:
            profile_f = open(self.fname)
        except IOError:
            return
        for lineno, line in enumerate(profile_f):
            line = line.strip()
            if not line or line.startswith("#"):
                continue

            # each data line is: "<test_key> <platform_key> <c1,c2,...>"
            test_key, platform_key, counts = line.split()
            per_fn = self.data[test_key]
            per_platform = per_fn[platform_key]
            c = [int(count) for count in counts.split(",")]
            per_platform["counts"] = c
            per_platform["lineno"] = lineno + 1
            per_platform["current_count"] = 0
        profile_f.close()

    def _write(self):
        # rewrite the whole stats file, sorted for stable diffs
        print(("Writing profile file %s" % self.fname))
        profile_f = open(self.fname, "w")
        profile_f.write(self._header())
        for test_key in sorted(self.data):

            per_fn = self.data[test_key]
            profile_f.write("\n# TEST: %s\n\n" % test_key)
            for platform_key in sorted(per_fn):
                per_platform = per_fn[platform_key]
                c = ",".join(str(count) for count in per_platform["counts"])
                profile_f.write("%s %s %s\n" % (test_key, platform_key, c))
        profile_f.close()
|
||||
|
||||
|
||||
def function_call_count(variance=0.05, times=1, warmup=0):
    """Assert a target for a test case's function call count.

    The main purpose of this assertion is to detect changes in
    callcounts for various functions - the actual number is not as important.
    Callcounts are stored in a file keyed to Python version and OS platform
    information. This file is generated automatically for new tests,
    and versioned so that unexpected changes in callcounts will be detected.

    :param variance: allowed fractional deviation from the recorded count.
    :param times: number of profiled invocations of the test function.
    :param warmup: number of unprofiled invocations run first, to exclude
     one-time setup costs from the count.
    """

    # use signature-rewriting decorator function so that pytest fixtures
    # still work on py27.  In Py3, update_wrapper() alone is good enough,
    # likely due to the introduction of __signature__.

    from sqlalchemy.util import decorator
    from sqlalchemy.util import deprecations
    from sqlalchemy.engine import row
    from sqlalchemy.testing import mock

    @decorator
    def wrap(fn, *args, **kw):

        # disable 2.0-deprecation warnings / LegacyRow key warnings while
        # profiling so their overhead doesn't pollute the callcounts
        with mock.patch.object(
            deprecations, "SQLALCHEMY_WARN_20", False
        ), mock.patch.object(
            row.LegacyRow, "_default_key_style", row.KEY_OBJECTS_NO_WARN
        ):
            for warm in range(warmup):
                fn(*args, **kw)

            timerange = range(times)
            with count_functions(variance=variance):
                for time in timerange:
                    rv = fn(*args, **kw)
                # return value of the last profiled invocation
                return rv

    return wrap
|
||||
|
||||
|
||||
@contextlib.contextmanager
def count_functions(variance=0.05):
    """Context manager that profiles its block and compares the total
    function call count against the recorded expectation.

    Skips the test when no stats are recorded for this platform and
    recording is not enabled; raises AssertionError when the observed count
    deviates from the recorded one by more than *variance*.
    """
    if cProfile is None:
        raise config._skip_test_exception("cProfile is not installed")

    if not _profile_stats.has_stats() and not _profile_stats.write:
        config.skip_test(
            "No profiling stats available on this "
            "platform for this function.  Run tests with "
            "--write-profiles to add statistics to %s for "
            "this platform." % _profile_stats.short_fname
        )

    # collect garbage first so prior allocations don't skew the counts
    gc_collect()

    pr = cProfile.Profile()
    pr.enable()
    # began = time.time()
    yield
    # ended = time.time()
    pr.disable()

    # s = compat.StringIO()
    stats = pstats.Stats(pr, stream=sys.stdout)

    # timespent = ended - began
    callcount = stats.total_calls

    expected = _profile_stats.result(callcount)

    if expected is None:
        expected_count = None
    else:
        line_no, expected_count = expected

    print(("Pstats calls: %d Expected %s" % (callcount, expected_count)))
    stats.sort_stats(*re.split(r"[, ]", _profile_stats.sort))
    stats.print_stats()
    if _profile_stats.dump:
        # write raw profile data to a per-test file for offline analysis
        base, ext = os.path.splitext(_profile_stats.dump)
        test_name = _current_test.split(".")[-1]
        dumpfile = "%s_%s%s" % (base, test_name, ext or ".profile")
        stats.dump_stats(dumpfile)
        print("Dumped stats to file %s" % dumpfile)
    # stats.print_callers()
    if _profile_stats.force_write:
        _profile_stats.replace(callcount)
    elif expected_count:
        deviance = int(callcount * variance)
        failed = abs(callcount - expected_count) > deviance

        if failed:
            if _profile_stats.write:
                _profile_stats.replace(callcount)
            else:
                raise AssertionError(
                    "Adjusted function call count %s not within %s%% "
                    "of expected %s, platform %s. Rerun with "
                    "--write-profiles to "
                    "regenerate this callcount."
                    % (
                        callcount,
                        (variance * 100),
                        expected_count,
                        _profile_stats.platform_key,
                    )
                )
|
||||
416
lib/sqlalchemy/testing/provision.py
Normal file
416
lib/sqlalchemy/testing/provision.py
Normal file
@@ -0,0 +1,416 @@
|
||||
import collections
|
||||
import logging
|
||||
|
||||
from . import config
|
||||
from . import engines
|
||||
from . import util
|
||||
from .. import exc
|
||||
from .. import inspect
|
||||
from ..engine import url as sa_url
|
||||
from ..sql import ddl
|
||||
from ..sql import schema
|
||||
from ..util import compat
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
FOLLOWER_IDENT = None
|
||||
|
||||
|
||||
class register(object):
    """Registry/dispatcher for per-backend provisioning functions.

    A function decorated with ``@register.init`` becomes the "*" (default)
    implementation; dialect provisioning modules then attach
    backend-specific overrides via ``.for_db("postgresql")`` etc.  Calling
    the registered object dispatches on the backend name of the config/URL
    in the first argument.
    """

    def __init__(self):
        # mapping of backend name (or "*") -> implementation function
        self.fns = {}

    @classmethod
    def init(cls, fn):
        """Decorator establishing *fn* as the default implementation."""
        return register().for_db("*")(fn)

    def for_db(self, *dbnames):
        """Decorator registering an implementation for the given backends."""
        def decorate(fn):
            for dbname in dbnames:
                self.fns[dbname] = fn
            return self

        return decorate

    def __call__(self, cfg, *arg):
        # cfg may be a URL string, a URL object, or a Config; normalize
        # to a URL to determine the backend to dispatch on
        if isinstance(cfg, compat.string_types):
            url = sa_url.make_url(cfg)
        elif isinstance(cfg, sa_url.URL):
            url = cfg
        else:
            url = cfg.db.url
        backend = url.get_backend_name()
        if backend in self.fns:
            return self.fns[backend](cfg, *arg)
        else:
            return self.fns["*"](cfg, *arg)
|
||||
|
||||
|
||||
def create_follower_db(follower_ident):
    """Create the *follower_ident* database on every distinct configured
    host, for parallel (xdist) test runs."""
    for cfg in _configs_for_db_operation():
        log.info("CREATE database %s, URI %r", follower_ident, cfg.db.url)
        create_db(cfg, cfg.db, follower_ident)
|
||||
|
||||
|
||||
def setup_config(db_url, options, file_config, follower_ident):
    """Build and register a testing ``Config`` for *db_url*.

    Loads the dialect's provisioning hooks, rewrites the URL for a
    follower database when *follower_ident* is given, creates a testing
    engine and verifies connectivity.
    """
    # load the dialect, which should also have it set up its provision
    # hooks

    dialect = sa_url.make_url(db_url).get_dialect()
    dialect.load_provisioning()

    if follower_ident:
        db_url = follower_url_from_main(db_url, follower_ident)
    db_opts = {}
    update_db_opts(db_url, db_opts)
    db_opts["scope"] = "global"
    eng = engines.testing_engine(db_url, db_opts)
    post_configure_engine(db_url, eng, follower_ident)
    # smoke-test the connection before registering the config
    eng.connect().close()

    cfg = config.Config.register(eng, db_opts, options, file_config)

    # a symbolic name that tests can use if they need to disambiguate
    # names across databases
    if follower_ident:
        config.ident = follower_ident

    if follower_ident:
        configure_follower(cfg, follower_ident)
    return cfg
|
||||
|
||||
|
||||
def drop_follower_db(follower_ident):
    """Drop the *follower_ident* database on every distinct configured
    host; counterpart to create_follower_db()."""
    for cfg in _configs_for_db_operation():
        log.info("DROP database %s, URI %r", follower_ident, cfg.db.url)
        drop_db(cfg, cfg.db, follower_ident)
|
||||
|
||||
|
||||
def generate_db_urls(db_urls, extra_drivers):
    """Generate a set of URLs to test given configured URLs plus additional
    driver names.

    Given::

        --dburi postgresql://db1 \
        --dburi postgresql://db2 \
        --dburi postgresql://db2 \
        --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true

    Noting that the default postgresql driver is psycopg2, the output
    would be::

        postgresql+psycopg2://db1
        postgresql+asyncpg://db1
        postgresql+psycopg2://db2
        postgresql+psycopg2://db3

    That is, for the driver in a --dburi, we want to keep that and use that
    driver for each URL it's part of .   For a driver that is only
    in --dbdrivers, we want to use it just once for one of the URLs.
    for a driver that is both coming from --dburi as well as --dbdrivers,
    we want to keep it in that dburi.

    Driver specific query options can be specified by added them to the
    driver name. For example, to enable the async fallback option for
    asyncpg::

        --dburi postgresql://db1 \
        --dbdriver=asyncpg?async_fallback=true

    """
    urls = set()

    backend_to_driver_we_already_have = collections.defaultdict(set)

    urls_plus_dialects = [
        (url_obj, url_obj.get_dialect())
        for url_obj in [sa_url.make_url(db_url) for db_url in db_urls]
    ]

    # first pass: record which drivers are already present per backend
    for url_obj, dialect in urls_plus_dialects:
        backend_to_driver_we_already_have[dialect.name].add(dialect.driver)

    backend_to_driver_we_need = {}

    for url_obj, dialect in urls_plus_dialects:
        backend = dialect.name
        dialect.load_provisioning()

        # the first URL seen for a backend computes the set of extra
        # drivers still needed; subsequent URLs share (and consume) it
        if backend not in backend_to_driver_we_need:
            backend_to_driver_we_need[backend] = extra_per_backend = set(
                extra_drivers
            ).difference(backend_to_driver_we_already_have[backend])
        else:
            extra_per_backend = backend_to_driver_we_need[backend]

        for driver_url in _generate_driver_urls(url_obj, extra_per_backend):
            if driver_url in urls:
                continue
            urls.add(driver_url)
            yield driver_url
|
||||
|
||||
|
||||
def _generate_driver_urls(url, extra_drivers):
    """Yield *url* itself plus variants of it for each applicable driver
    in *extra_drivers*.

    NOTE: intentionally mutates *extra_drivers*, removing each driver that
    was successfully applied, so the caller does not reuse it for another
    URL of the same backend.
    """
    main_driver = url.get_driver_name()
    extra_drivers.discard(main_driver)

    url = generate_driver_url(url, main_driver, "")
    yield str(url)

    for drv in list(extra_drivers):

        # a driver spec may carry query options, e.g.
        # "asyncpg?async_fallback=true"
        if "?" in drv:

            driver_only, query_str = drv.split("?", 1)

        else:
            driver_only = drv
            query_str = None

        new_url = generate_driver_url(url, driver_only, query_str)
        if new_url:
            extra_drivers.remove(drv)

            yield str(new_url)
|
||||
|
||||
|
||||
@register.init
def generate_driver_url(url, driver, query_str):
    """Return *url* rewritten to use *driver* (plus optional query string),
    or None when no dialect is installed for that combination."""
    backend = url.get_backend_name()

    new_url = url.set(
        drivername="%s+%s" % (backend, driver),
    )
    if query_str:
        new_url = new_url.update_query_string(query_str)

    try:
        # probe whether the backend+driver dialect is importable
        new_url.get_dialect()
    except exc.NoSuchModuleError:
        return None
    else:
        return new_url
|
||||
|
||||
|
||||
def _configs_for_db_operation():
    """Yield one config per distinct (backend, user, host, database),
    disposing all engine connection pools before and after so that no
    pooled connections hold locks during create/drop operations."""
    hosts = set()

    for cfg in config.Config.all_configs():
        cfg.db.dispose()

    for cfg in config.Config.all_configs():
        url = cfg.db.url
        backend = url.get_backend_name()
        host_conf = (backend, url.username, url.host, url.database)

        if host_conf not in hosts:
            yield cfg
            hosts.add(host_conf)

    for cfg in config.Config.all_configs():
        cfg.db.dispose()
|
||||
|
||||
|
||||
@register.init
def drop_all_schema_objects_pre_tables(cfg, eng):
    """Hook run before tables are dropped; default is a no-op, with
    dialect provisioning modules expected to override via @register."""
    pass
|
||||
|
||||
|
||||
@register.init
def drop_all_schema_objects_post_tables(cfg, eng):
    """Hook run after tables are dropped; default is a no-op, with
    dialect provisioning modules expected to override via @register."""
    pass
|
||||
|
||||
|
||||
def drop_all_schema_objects(cfg, eng):
    """Drop all schema objects from the test database: views first, then
    tables (in the default and test schemas), then sequences, with the
    pre/post dialect hooks invoked around the table drops."""

    drop_all_schema_objects_pre_tables(cfg, eng)

    inspector = inspect(eng)
    try:
        view_names = inspector.get_view_names()
    except NotImplementedError:
        # dialect doesn't support view reflection; nothing to drop
        pass
    else:
        with eng.begin() as conn:
            for vname in view_names:
                conn.execute(
                    ddl._DropView(schema.Table(vname, schema.MetaData()))
                )

    if config.requirements.schemas.enabled_for_config(cfg):
        try:
            view_names = inspector.get_view_names(schema="test_schema")
        except NotImplementedError:
            pass
        else:
            with eng.begin() as conn:
                for vname in view_names:
                    conn.execute(
                        ddl._DropView(
                            schema.Table(
                                vname,
                                schema.MetaData(),
                                schema="test_schema",
                            )
                        )
                    )

    util.drop_all_tables(eng, inspector)
    if config.requirements.schemas.enabled_for_config(cfg):
        util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
        util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)

    drop_all_schema_objects_post_tables(cfg, eng)

    if config.requirements.sequences.enabled_for_config(cfg):
        with eng.begin() as conn:
            for seq in inspector.get_sequence_names():
                conn.execute(ddl.DropSequence(schema.Sequence(seq)))
            if config.requirements.schemas.enabled_for_config(cfg):
                for schema_name in [cfg.test_schema, cfg.test_schema_2]:
                    for seq in inspector.get_sequence_names(
                        schema=schema_name
                    ):
                        conn.execute(
                            ddl.DropSequence(
                                schema.Sequence(seq, schema=schema_name)
                            )
                        )
|
||||
|
||||
|
||||
@register.init
def create_db(cfg, eng, ident):
    """Dynamically create a database for testing.

    Used when a test run will employ multiple processes, e.g., when run
    via `tox` or `pytest -n4`.

    Default raises; dialect provisioning modules register backend-specific
    implementations via @register.
    """
    raise NotImplementedError(
        "no DB creation routine for cfg: %s" % (eng.url,)
    )
|
||||
|
||||
|
||||
@register.init
def drop_db(cfg, eng, ident):
    """Drop a database that we dynamically created for testing.

    Default raises; dialect provisioning modules register backend-specific
    implementations via @register.
    """
    raise NotImplementedError("no DB drop routine for cfg: %s" % (eng.url,))
|
||||
|
||||
|
||||
@register.init
def update_db_opts(db_url, db_opts):
    """Set database options (db_opts) for a test database that we created.

    Default is a no-op; dialects may register overrides via @register.
    """
    pass
|
||||
|
||||
|
||||
@register.init
def post_configure_engine(url, engine, follower_ident):
    """Perform extra steps after configuring an engine for testing.

    (For the internal dialects, currently only used by sqlite, oracle)
    """
    pass
|
||||
|
||||
|
||||
@register.init
def follower_url_from_main(url, ident):
    """Create a connection URL for a dynamically-created test database.

    :param url: the connection URL specified when the test run was invoked
    :param ident: the pytest-xdist "worker identifier" to be used as the
     database name

    Default behavior simply swaps the database name for *ident*.
    """
    url = sa_url.make_url(url)
    return url.set(database=ident)
|
||||
|
||||
|
||||
@register.init
def configure_follower(cfg, ident):
    """Create dialect-specific config settings for a follower database.

    Default is a no-op; dialects may register overrides via @register.
    """
    pass
|
||||
|
||||
|
||||
@register.init
def run_reap_dbs(url, ident):
    """Remove databases that were created during the test process, after the
    process has ended.

    This is an optional step that is invoked for certain backends that do not
    reliably release locks on the database as long as a process is still in
    use. For the internal dialects, this is currently only necessary for
    mssql and oracle.
    """
    pass
|
||||
|
||||
|
||||
def reap_dbs(idents_file):
    """Drop leftover follower databases recorded in *idents_file*.

    The file contains one ``<db_name> <db_url>`` pair per line.  Entries
    are grouped by (backend, host) so that each host's ``run_reap_dbs``
    hook is invoked once with the full set of idents for that host.

    :param idents_file: path to the idents file written during the test run.
    """
    log.info("Reaping databases...")

    urls = collections.defaultdict(set)
    idents = collections.defaultdict(set)
    dialects = {}

    with open(idents_file) as file_:
        for line in file_:
            line = line.strip()
            if not line:
                # tolerate blank/trailing lines in the idents file; a bare
                # line.split(" ") would otherwise raise ValueError here
                continue
            db_name, db_url = line.split(" ")
            url_obj = sa_url.make_url(db_url)
            if db_name not in dialects:
                # load provisioning hooks once per distinct database name
                dialects[db_name] = url_obj.get_dialect()
                dialects[db_name].load_provisioning()
            url_key = (url_obj.get_backend_name(), url_obj.host)
            urls[url_key].add(db_url)
            idents[url_key].add(db_name)

    for url_key in urls:
        # any URL for this host will do for the reap operation
        url = list(urls[url_key])[0]
        ident = idents[url_key]
        run_reap_dbs(url, ident)
|
||||
|
||||
|
||||
@register.init
|
||||
def temp_table_keyword_args(cfg, eng):
|
||||
"""Specify keyword arguments for creating a temporary Table.
|
||||
|
||||
Dialect-specific implementations of this method will return the
|
||||
kwargs that are passed to the Table method when creating a temporary
|
||||
table for testing, e.g., in the define_temp_tables method of the
|
||||
ComponentReflectionTest class in suite/test_reflection.py
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"no temp table keyword args routine for cfg: %s" % (eng.url,)
|
||||
)
|
||||
|
||||
|
||||
@register.init
|
||||
def prepare_for_drop_tables(config, connection):
|
||||
pass
|
||||
|
||||
|
||||
@register.init
|
||||
def stop_test_class_outside_fixtures(config, db, testcls):
|
||||
pass
|
||||
|
||||
|
||||
@register.init
|
||||
def get_temp_table_name(cfg, eng, base_name):
|
||||
"""Specify table name for creating a temporary Table.
|
||||
|
||||
Dialect-specific implementations of this method will return the
|
||||
name to use when creating a temporary table for testing,
|
||||
e.g., in the define_temp_tables method of the
|
||||
ComponentReflectionTest class in suite/test_reflection.py
|
||||
|
||||
Default to just the base name since that's what most dialects will
|
||||
use. The mssql dialect's implementation will need a "#" prepended.
|
||||
"""
|
||||
return base_name
|
||||
|
||||
|
||||
@register.init
|
||||
def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
|
||||
raise NotImplementedError(
|
||||
"backend does not implement a schema name set function: %s"
|
||||
% (cfg.db.url,)
|
||||
)
|
||||
1518
lib/sqlalchemy/testing/requirements.py
Normal file
1518
lib/sqlalchemy/testing/requirements.py
Normal file
File diff suppressed because it is too large
Load Diff
218
lib/sqlalchemy/testing/schema.py
Normal file
218
lib/sqlalchemy/testing/schema.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# testing/schema.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
import sys
|
||||
|
||||
from . import config
|
||||
from . import exclusions
|
||||
from .. import event
|
||||
from .. import schema
|
||||
from .. import types as sqltypes
|
||||
from ..util import OrderedDict
|
||||
|
||||
|
||||
__all__ = ["Table", "Column"]
|
||||
|
||||
table_options = {}
|
||||
|
||||
|
||||
def Table(*args, **kw):
|
||||
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
|
||||
|
||||
test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
|
||||
|
||||
kw.update(table_options)
|
||||
|
||||
if exclusions.against(config._current, "mysql"):
|
||||
if (
|
||||
"mysql_engine" not in kw
|
||||
and "mysql_type" not in kw
|
||||
and "autoload_with" not in kw
|
||||
):
|
||||
if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
|
||||
kw["mysql_engine"] = "InnoDB"
|
||||
else:
|
||||
kw["mysql_engine"] = "MyISAM"
|
||||
elif exclusions.against(config._current, "mariadb"):
|
||||
if (
|
||||
"mariadb_engine" not in kw
|
||||
and "mariadb_type" not in kw
|
||||
and "autoload_with" not in kw
|
||||
):
|
||||
if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts:
|
||||
kw["mariadb_engine"] = "InnoDB"
|
||||
else:
|
||||
kw["mariadb_engine"] = "MyISAM"
|
||||
|
||||
# Apply some default cascading rules for self-referential foreign keys.
|
||||
# MySQL InnoDB has some issues around selecting self-refs too.
|
||||
if exclusions.against(config._current, "firebird"):
|
||||
table_name = args[0]
|
||||
unpack = config.db.dialect.identifier_preparer.unformat_identifiers
|
||||
|
||||
# Only going after ForeignKeys in Columns. May need to
|
||||
# expand to ForeignKeyConstraint too.
|
||||
fks = [
|
||||
fk
|
||||
for col in args
|
||||
if isinstance(col, schema.Column)
|
||||
for fk in col.foreign_keys
|
||||
]
|
||||
|
||||
for fk in fks:
|
||||
# root around in raw spec
|
||||
ref = fk._colspec
|
||||
if isinstance(ref, schema.Column):
|
||||
name = ref.table.name
|
||||
else:
|
||||
# take just the table name: on FB there cannot be
|
||||
# a schema, so the first element is always the
|
||||
# table name, possibly followed by the field name
|
||||
name = unpack(ref)[0]
|
||||
if name == table_name:
|
||||
if fk.ondelete is None:
|
||||
fk.ondelete = "CASCADE"
|
||||
if fk.onupdate is None:
|
||||
fk.onupdate = "CASCADE"
|
||||
|
||||
return schema.Table(*args, **kw)
|
||||
|
||||
|
||||
def Column(*args, **kw):
|
||||
"""A schema.Column wrapper/hook for dialect-specific tweaks."""
|
||||
|
||||
test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
|
||||
|
||||
if not config.requirements.foreign_key_ddl.enabled_for_config(config):
|
||||
args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
|
||||
|
||||
col = schema.Column(*args, **kw)
|
||||
if test_opts.get("test_needs_autoincrement", False) and kw.get(
|
||||
"primary_key", False
|
||||
):
|
||||
|
||||
if col.default is None and col.server_default is None:
|
||||
col.autoincrement = True
|
||||
|
||||
# allow any test suite to pick up on this
|
||||
col.info["test_needs_autoincrement"] = True
|
||||
|
||||
# hardcoded rule for firebird, oracle; this should
|
||||
# be moved out
|
||||
if exclusions.against(config._current, "firebird", "oracle"):
|
||||
|
||||
def add_seq(c, tbl):
|
||||
c._init_items(
|
||||
schema.Sequence(
|
||||
_truncate_name(
|
||||
config.db.dialect, tbl.name + "_" + c.name + "_seq"
|
||||
),
|
||||
optional=True,
|
||||
)
|
||||
)
|
||||
|
||||
event.listen(col, "after_parent_attach", add_seq, propagate=True)
|
||||
return col
|
||||
|
||||
|
||||
class eq_type_affinity(object):
|
||||
"""Helper to compare types inside of datastructures based on affinity.
|
||||
|
||||
E.g.::
|
||||
|
||||
eq_(
|
||||
inspect(connection).get_columns("foo"),
|
||||
[
|
||||
{
|
||||
"name": "id",
|
||||
"type": testing.eq_type_affinity(sqltypes.INTEGER),
|
||||
"nullable": False,
|
||||
"default": None,
|
||||
"autoincrement": False,
|
||||
},
|
||||
{
|
||||
"name": "data",
|
||||
"type": testing.eq_type_affinity(sqltypes.NullType),
|
||||
"nullable": True,
|
||||
"default": None,
|
||||
"autoincrement": False,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, target):
|
||||
self.target = sqltypes.to_instance(target)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.target._type_affinity is other._type_affinity
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.target._type_affinity is not other._type_affinity
|
||||
|
||||
|
||||
class eq_clause_element(object):
|
||||
"""Helper to compare SQL structures based on compare()"""
|
||||
|
||||
def __init__(self, target):
|
||||
self.target = target
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.target.compare(other)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.target.compare(other)
|
||||
|
||||
|
||||
def _truncate_name(dialect, name):
|
||||
if len(name) > dialect.max_identifier_length:
|
||||
return (
|
||||
name[0 : max(dialect.max_identifier_length - 6, 0)]
|
||||
+ "_"
|
||||
+ hex(hash(name) % 64)[2:]
|
||||
)
|
||||
else:
|
||||
return name
|
||||
|
||||
|
||||
def pep435_enum(name):
|
||||
# Implements PEP 435 in the minimal fashion needed by SQLAlchemy
|
||||
__members__ = OrderedDict()
|
||||
|
||||
def __init__(self, name, value, alias=None):
|
||||
self.name = name
|
||||
self.value = value
|
||||
self.__members__[name] = self
|
||||
value_to_member[value] = self
|
||||
setattr(self.__class__, name, self)
|
||||
if alias:
|
||||
self.__members__[alias] = self
|
||||
setattr(self.__class__, alias, self)
|
||||
|
||||
value_to_member = {}
|
||||
|
||||
@classmethod
|
||||
def get(cls, value):
|
||||
return value_to_member[value]
|
||||
|
||||
someenum = type(
|
||||
name,
|
||||
(object,),
|
||||
{"__members__": __members__, "__init__": __init__, "get": get},
|
||||
)
|
||||
|
||||
# getframe() trick for pickling I don't understand courtesy
|
||||
# Python namedtuple()
|
||||
try:
|
||||
module = sys._getframe(1).f_globals.get("__name__", "__main__")
|
||||
except (AttributeError, ValueError):
|
||||
pass
|
||||
if module is not None:
|
||||
someenum.__module__ = module
|
||||
|
||||
return someenum
|
||||
13
lib/sqlalchemy/testing/suite/__init__.py
Normal file
13
lib/sqlalchemy/testing/suite/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .test_cte import * # noqa
|
||||
from .test_ddl import * # noqa
|
||||
from .test_deprecations import * # noqa
|
||||
from .test_dialect import * # noqa
|
||||
from .test_insert import * # noqa
|
||||
from .test_reflection import * # noqa
|
||||
from .test_results import * # noqa
|
||||
from .test_rowcount import * # noqa
|
||||
from .test_select import * # noqa
|
||||
from .test_sequence import * # noqa
|
||||
from .test_types import * # noqa
|
||||
from .test_unicode_ddl import * # noqa
|
||||
from .test_update_delete import * # noqa
|
||||
204
lib/sqlalchemy/testing/suite/test_cte.py
Normal file
204
lib/sqlalchemy/testing/suite/test_cte.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import ForeignKey
|
||||
from ... import Integer
|
||||
from ... import select
|
||||
from ... import String
|
||||
from ... import testing
|
||||
|
||||
|
||||
class CTETest(fixtures.TablesTest):
|
||||
__backend__ = True
|
||||
__requires__ = ("ctes",)
|
||||
|
||||
run_inserts = "each"
|
||||
run_deletes = "each"
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"some_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
Column("parent_id", ForeignKey("some_table.id")),
|
||||
)
|
||||
|
||||
Table(
|
||||
"some_other_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
Column("parent_id", Integer),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
connection.execute(
|
||||
cls.tables.some_table.insert(),
|
||||
[
|
||||
{"id": 1, "data": "d1", "parent_id": None},
|
||||
{"id": 2, "data": "d2", "parent_id": 1},
|
||||
{"id": 3, "data": "d3", "parent_id": 1},
|
||||
{"id": 4, "data": "d4", "parent_id": 3},
|
||||
{"id": 5, "data": "d5", "parent_id": 3},
|
||||
],
|
||||
)
|
||||
|
||||
def test_select_nonrecursive_round_trip(self, connection):
|
||||
some_table = self.tables.some_table
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte")
|
||||
)
|
||||
result = connection.execute(
|
||||
select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
|
||||
)
|
||||
eq_(result.fetchall(), [("d4",)])
|
||||
|
||||
def test_select_recursive_round_trip(self, connection):
|
||||
some_table = self.tables.some_table
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte", recursive=True)
|
||||
)
|
||||
|
||||
cte_alias = cte.alias("c1")
|
||||
st1 = some_table.alias()
|
||||
# note that SQL Server requires this to be UNION ALL,
|
||||
# can't be UNION
|
||||
cte = cte.union_all(
|
||||
select(st1).where(st1.c.id == cte_alias.c.parent_id)
|
||||
)
|
||||
result = connection.execute(
|
||||
select(cte.c.data)
|
||||
.where(cte.c.data != "d2")
|
||||
.order_by(cte.c.data.desc())
|
||||
)
|
||||
eq_(
|
||||
result.fetchall(),
|
||||
[("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
|
||||
)
|
||||
|
||||
def test_insert_from_select_round_trip(self, connection):
|
||||
some_table = self.tables.some_table
|
||||
some_other_table = self.tables.some_other_table
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte")
|
||||
)
|
||||
connection.execute(
|
||||
some_other_table.insert().from_select(
|
||||
["id", "data", "parent_id"], select(cte)
|
||||
)
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(some_other_table).order_by(some_other_table.c.id)
|
||||
).fetchall(),
|
||||
[(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
|
||||
)
|
||||
|
||||
@testing.requires.ctes_with_update_delete
|
||||
@testing.requires.update_from
|
||||
def test_update_from_round_trip(self, connection):
|
||||
some_table = self.tables.some_table
|
||||
some_other_table = self.tables.some_other_table
|
||||
|
||||
connection.execute(
|
||||
some_other_table.insert().from_select(
|
||||
["id", "data", "parent_id"], select(some_table)
|
||||
)
|
||||
)
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte")
|
||||
)
|
||||
connection.execute(
|
||||
some_other_table.update()
|
||||
.values(parent_id=5)
|
||||
.where(some_other_table.c.data == cte.c.data)
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(some_other_table).order_by(some_other_table.c.id)
|
||||
).fetchall(),
|
||||
[
|
||||
(1, "d1", None),
|
||||
(2, "d2", 5),
|
||||
(3, "d3", 5),
|
||||
(4, "d4", 5),
|
||||
(5, "d5", 3),
|
||||
],
|
||||
)
|
||||
|
||||
@testing.requires.ctes_with_update_delete
|
||||
@testing.requires.delete_from
|
||||
def test_delete_from_round_trip(self, connection):
|
||||
some_table = self.tables.some_table
|
||||
some_other_table = self.tables.some_other_table
|
||||
|
||||
connection.execute(
|
||||
some_other_table.insert().from_select(
|
||||
["id", "data", "parent_id"], select(some_table)
|
||||
)
|
||||
)
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte")
|
||||
)
|
||||
connection.execute(
|
||||
some_other_table.delete().where(
|
||||
some_other_table.c.data == cte.c.data
|
||||
)
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(some_other_table).order_by(some_other_table.c.id)
|
||||
).fetchall(),
|
||||
[(1, "d1", None), (5, "d5", 3)],
|
||||
)
|
||||
|
||||
@testing.requires.ctes_with_update_delete
|
||||
def test_delete_scalar_subq_round_trip(self, connection):
|
||||
|
||||
some_table = self.tables.some_table
|
||||
some_other_table = self.tables.some_other_table
|
||||
|
||||
connection.execute(
|
||||
some_other_table.insert().from_select(
|
||||
["id", "data", "parent_id"], select(some_table)
|
||||
)
|
||||
)
|
||||
|
||||
cte = (
|
||||
select(some_table)
|
||||
.where(some_table.c.data.in_(["d2", "d3", "d4"]))
|
||||
.cte("some_cte")
|
||||
)
|
||||
connection.execute(
|
||||
some_other_table.delete().where(
|
||||
some_other_table.c.data
|
||||
== select(cte.c.data)
|
||||
.where(cte.c.id == some_other_table.c.id)
|
||||
.scalar_subquery()
|
||||
)
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(some_other_table).order_by(some_other_table.c.id)
|
||||
).fetchall(),
|
||||
[(1, "d1", None), (5, "d5", 3)],
|
||||
)
|
||||
381
lib/sqlalchemy/testing/suite/test_ddl.py
Normal file
381
lib/sqlalchemy/testing/suite/test_ddl.py
Normal file
@@ -0,0 +1,381 @@
|
||||
import random
|
||||
|
||||
from . import testing
|
||||
from .. import config
|
||||
from .. import fixtures
|
||||
from .. import util
|
||||
from ..assertions import eq_
|
||||
from ..assertions import is_false
|
||||
from ..assertions import is_true
|
||||
from ..config import requirements
|
||||
from ..schema import Table
|
||||
from ... import CheckConstraint
|
||||
from ... import Column
|
||||
from ... import ForeignKeyConstraint
|
||||
from ... import Index
|
||||
from ... import inspect
|
||||
from ... import Integer
|
||||
from ... import schema
|
||||
from ... import String
|
||||
from ... import UniqueConstraint
|
||||
|
||||
|
||||
class TableDDLTest(fixtures.TestBase):
|
||||
__backend__ = True
|
||||
|
||||
def _simple_fixture(self, schema=None):
|
||||
return Table(
|
||||
"test_table",
|
||||
self.metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
schema=schema,
|
||||
)
|
||||
|
||||
def _underscore_fixture(self):
|
||||
return Table(
|
||||
"_test_table",
|
||||
self.metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("_data", String(50)),
|
||||
)
|
||||
|
||||
def _table_index_fixture(self, schema=None):
|
||||
table = self._simple_fixture(schema=schema)
|
||||
idx = Index("test_index", table.c.data)
|
||||
return table, idx
|
||||
|
||||
def _simple_roundtrip(self, table):
|
||||
with config.db.begin() as conn:
|
||||
conn.execute(table.insert().values((1, "some data")))
|
||||
result = conn.execute(table.select())
|
||||
eq_(result.first(), (1, "some data"))
|
||||
|
||||
@requirements.create_table
|
||||
@util.provide_metadata
|
||||
def test_create_table(self):
|
||||
table = self._simple_fixture()
|
||||
table.create(config.db, checkfirst=False)
|
||||
self._simple_roundtrip(table)
|
||||
|
||||
@requirements.create_table
|
||||
@requirements.schemas
|
||||
@util.provide_metadata
|
||||
def test_create_table_schema(self):
|
||||
table = self._simple_fixture(schema=config.test_schema)
|
||||
table.create(config.db, checkfirst=False)
|
||||
self._simple_roundtrip(table)
|
||||
|
||||
@requirements.drop_table
|
||||
@util.provide_metadata
|
||||
def test_drop_table(self):
|
||||
table = self._simple_fixture()
|
||||
table.create(config.db, checkfirst=False)
|
||||
table.drop(config.db, checkfirst=False)
|
||||
|
||||
@requirements.create_table
|
||||
@util.provide_metadata
|
||||
def test_underscore_names(self):
|
||||
table = self._underscore_fixture()
|
||||
table.create(config.db, checkfirst=False)
|
||||
self._simple_roundtrip(table)
|
||||
|
||||
@requirements.comment_reflection
|
||||
@util.provide_metadata
|
||||
def test_add_table_comment(self, connection):
|
||||
table = self._simple_fixture()
|
||||
table.create(connection, checkfirst=False)
|
||||
table.comment = "a comment"
|
||||
connection.execute(schema.SetTableComment(table))
|
||||
eq_(
|
||||
inspect(connection).get_table_comment("test_table"),
|
||||
{"text": "a comment"},
|
||||
)
|
||||
|
||||
@requirements.comment_reflection
|
||||
@util.provide_metadata
|
||||
def test_drop_table_comment(self, connection):
|
||||
table = self._simple_fixture()
|
||||
table.create(connection, checkfirst=False)
|
||||
table.comment = "a comment"
|
||||
connection.execute(schema.SetTableComment(table))
|
||||
connection.execute(schema.DropTableComment(table))
|
||||
eq_(
|
||||
inspect(connection).get_table_comment("test_table"), {"text": None}
|
||||
)
|
||||
|
||||
@requirements.table_ddl_if_exists
|
||||
@util.provide_metadata
|
||||
def test_create_table_if_not_exists(self, connection):
|
||||
table = self._simple_fixture()
|
||||
|
||||
connection.execute(schema.CreateTable(table, if_not_exists=True))
|
||||
|
||||
is_true(inspect(connection).has_table("test_table"))
|
||||
connection.execute(schema.CreateTable(table, if_not_exists=True))
|
||||
|
||||
@requirements.index_ddl_if_exists
|
||||
@util.provide_metadata
|
||||
def test_create_index_if_not_exists(self, connection):
|
||||
table, idx = self._table_index_fixture()
|
||||
|
||||
connection.execute(schema.CreateTable(table, if_not_exists=True))
|
||||
is_true(inspect(connection).has_table("test_table"))
|
||||
is_false(
|
||||
"test_index"
|
||||
in [
|
||||
ix["name"]
|
||||
for ix in inspect(connection).get_indexes("test_table")
|
||||
]
|
||||
)
|
||||
|
||||
connection.execute(schema.CreateIndex(idx, if_not_exists=True))
|
||||
|
||||
is_true(
|
||||
"test_index"
|
||||
in [
|
||||
ix["name"]
|
||||
for ix in inspect(connection).get_indexes("test_table")
|
||||
]
|
||||
)
|
||||
|
||||
connection.execute(schema.CreateIndex(idx, if_not_exists=True))
|
||||
|
||||
@requirements.table_ddl_if_exists
|
||||
@util.provide_metadata
|
||||
def test_drop_table_if_exists(self, connection):
|
||||
table = self._simple_fixture()
|
||||
|
||||
table.create(connection)
|
||||
|
||||
is_true(inspect(connection).has_table("test_table"))
|
||||
|
||||
connection.execute(schema.DropTable(table, if_exists=True))
|
||||
|
||||
is_false(inspect(connection).has_table("test_table"))
|
||||
|
||||
connection.execute(schema.DropTable(table, if_exists=True))
|
||||
|
||||
@requirements.index_ddl_if_exists
|
||||
@util.provide_metadata
|
||||
def test_drop_index_if_exists(self, connection):
|
||||
table, idx = self._table_index_fixture()
|
||||
|
||||
table.create(connection)
|
||||
|
||||
is_true(
|
||||
"test_index"
|
||||
in [
|
||||
ix["name"]
|
||||
for ix in inspect(connection).get_indexes("test_table")
|
||||
]
|
||||
)
|
||||
|
||||
connection.execute(schema.DropIndex(idx, if_exists=True))
|
||||
|
||||
is_false(
|
||||
"test_index"
|
||||
in [
|
||||
ix["name"]
|
||||
for ix in inspect(connection).get_indexes("test_table")
|
||||
]
|
||||
)
|
||||
|
||||
connection.execute(schema.DropIndex(idx, if_exists=True))
|
||||
|
||||
|
||||
class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest):
|
||||
pass
|
||||
|
||||
|
||||
class LongNameBlowoutTest(fixtures.TestBase):
|
||||
"""test the creation of a variety of DDL structures and ensure
|
||||
label length limits pass on backends
|
||||
|
||||
"""
|
||||
|
||||
__backend__ = True
|
||||
|
||||
def fk(self, metadata, connection):
|
||||
convention = {
|
||||
"fk": "foreign_key_%(table_name)s_"
|
||||
"%(column_0_N_name)s_"
|
||||
"%(referred_table_name)s_"
|
||||
+ (
|
||||
"_".join(
|
||||
"".join(random.choice("abcdef") for j in range(20))
|
||||
for i in range(10)
|
||||
)
|
||||
),
|
||||
}
|
||||
metadata.naming_convention = convention
|
||||
|
||||
Table(
|
||||
"a_things_with_stuff",
|
||||
metadata,
|
||||
Column("id_long_column_name", Integer, primary_key=True),
|
||||
test_needs_fk=True,
|
||||
)
|
||||
|
||||
cons = ForeignKeyConstraint(
|
||||
["aid"], ["a_things_with_stuff.id_long_column_name"]
|
||||
)
|
||||
Table(
|
||||
"b_related_things_of_value",
|
||||
metadata,
|
||||
Column(
|
||||
"aid",
|
||||
),
|
||||
cons,
|
||||
test_needs_fk=True,
|
||||
)
|
||||
actual_name = cons.name
|
||||
|
||||
metadata.create_all(connection)
|
||||
|
||||
if testing.requires.foreign_key_constraint_name_reflection.enabled:
|
||||
insp = inspect(connection)
|
||||
fks = insp.get_foreign_keys("b_related_things_of_value")
|
||||
reflected_name = fks[0]["name"]
|
||||
|
||||
return actual_name, reflected_name
|
||||
else:
|
||||
return actual_name, None
|
||||
|
||||
def pk(self, metadata, connection):
|
||||
convention = {
|
||||
"pk": "primary_key_%(table_name)s_"
|
||||
"%(column_0_N_name)s"
|
||||
+ (
|
||||
"_".join(
|
||||
"".join(random.choice("abcdef") for j in range(30))
|
||||
for i in range(10)
|
||||
)
|
||||
),
|
||||
}
|
||||
metadata.naming_convention = convention
|
||||
|
||||
a = Table(
|
||||
"a_things_with_stuff",
|
||||
metadata,
|
||||
Column("id_long_column_name", Integer, primary_key=True),
|
||||
Column("id_another_long_name", Integer, primary_key=True),
|
||||
)
|
||||
cons = a.primary_key
|
||||
actual_name = cons.name
|
||||
|
||||
metadata.create_all(connection)
|
||||
insp = inspect(connection)
|
||||
pk = insp.get_pk_constraint("a_things_with_stuff")
|
||||
reflected_name = pk["name"]
|
||||
return actual_name, reflected_name
|
||||
|
||||
def ix(self, metadata, connection):
|
||||
convention = {
|
||||
"ix": "index_%(table_name)s_"
|
||||
"%(column_0_N_name)s"
|
||||
+ (
|
||||
"_".join(
|
||||
"".join(random.choice("abcdef") for j in range(30))
|
||||
for i in range(10)
|
||||
)
|
||||
),
|
||||
}
|
||||
metadata.naming_convention = convention
|
||||
|
||||
a = Table(
|
||||
"a_things_with_stuff",
|
||||
metadata,
|
||||
Column("id_long_column_name", Integer, primary_key=True),
|
||||
Column("id_another_long_name", Integer),
|
||||
)
|
||||
cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name)
|
||||
actual_name = cons.name
|
||||
|
||||
metadata.create_all(connection)
|
||||
insp = inspect(connection)
|
||||
ix = insp.get_indexes("a_things_with_stuff")
|
||||
reflected_name = ix[0]["name"]
|
||||
return actual_name, reflected_name
|
||||
|
||||
def uq(self, metadata, connection):
|
||||
convention = {
|
||||
"uq": "unique_constraint_%(table_name)s_"
|
||||
"%(column_0_N_name)s"
|
||||
+ (
|
||||
"_".join(
|
||||
"".join(random.choice("abcdef") for j in range(30))
|
||||
for i in range(10)
|
||||
)
|
||||
),
|
||||
}
|
||||
metadata.naming_convention = convention
|
||||
|
||||
cons = UniqueConstraint("id_long_column_name", "id_another_long_name")
|
||||
Table(
|
||||
"a_things_with_stuff",
|
||||
metadata,
|
||||
Column("id_long_column_name", Integer, primary_key=True),
|
||||
Column("id_another_long_name", Integer),
|
||||
cons,
|
||||
)
|
||||
actual_name = cons.name
|
||||
|
||||
metadata.create_all(connection)
|
||||
insp = inspect(connection)
|
||||
uq = insp.get_unique_constraints("a_things_with_stuff")
|
||||
reflected_name = uq[0]["name"]
|
||||
return actual_name, reflected_name
|
||||
|
||||
def ck(self, metadata, connection):
|
||||
convention = {
|
||||
"ck": "check_constraint_%(table_name)s"
|
||||
+ (
|
||||
"_".join(
|
||||
"".join(random.choice("abcdef") for j in range(30))
|
||||
for i in range(10)
|
||||
)
|
||||
),
|
||||
}
|
||||
metadata.naming_convention = convention
|
||||
|
||||
cons = CheckConstraint("some_long_column_name > 5")
|
||||
Table(
|
||||
"a_things_with_stuff",
|
||||
metadata,
|
||||
Column("id_long_column_name", Integer, primary_key=True),
|
||||
Column("some_long_column_name", Integer),
|
||||
cons,
|
||||
)
|
||||
actual_name = cons.name
|
||||
|
||||
metadata.create_all(connection)
|
||||
insp = inspect(connection)
|
||||
ck = insp.get_check_constraints("a_things_with_stuff")
|
||||
reflected_name = ck[0]["name"]
|
||||
return actual_name, reflected_name
|
||||
|
||||
@testing.combinations(
|
||||
("fk",),
|
||||
("pk",),
|
||||
("ix",),
|
||||
("ck", testing.requires.check_constraint_reflection.as_skips()),
|
||||
("uq", testing.requires.unique_constraint_reflection.as_skips()),
|
||||
argnames="type_",
|
||||
)
|
||||
def test_long_convention_name(self, type_, metadata, connection):
|
||||
actual_name, reflected_name = getattr(self, type_)(
|
||||
metadata, connection
|
||||
)
|
||||
|
||||
assert len(actual_name) > 255
|
||||
|
||||
if reflected_name is not None:
|
||||
overlap = actual_name[0 : len(reflected_name)]
|
||||
if len(overlap) < len(actual_name):
|
||||
eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5])
|
||||
else:
|
||||
eq_(overlap, reflected_name)
|
||||
|
||||
|
||||
__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest")
|
||||
145
lib/sqlalchemy/testing/suite/test_deprecations.py
Normal file
145
lib/sqlalchemy/testing/suite/test_deprecations.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import Integer
|
||||
from ... import select
|
||||
from ... import testing
|
||||
from ... import union
|
||||
|
||||
|
||||
class DeprecatedCompoundSelectTest(fixtures.TablesTest):
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"some_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("x", Integer),
|
||||
Column("y", Integer),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
connection.execute(
|
||||
cls.tables.some_table.insert(),
|
||||
[
|
||||
{"id": 1, "x": 1, "y": 2},
|
||||
{"id": 2, "x": 2, "y": 3},
|
||||
{"id": 3, "x": 3, "y": 4},
|
||||
{"id": 4, "x": 4, "y": 5},
|
||||
],
|
||||
)
|
||||
|
||||
def _assert_result(self, conn, select, result, params=()):
|
||||
eq_(conn.execute(select, params).fetchall(), result)
|
||||
|
||||
def test_plain_union(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = select(table).where(table.c.id == 2)
|
||||
s2 = select(table).where(table.c.id == 3)
|
||||
|
||||
u1 = union(s1, s2)
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
# note we've had to remove one use case entirely, which is this
|
||||
# one. the Select gets its FROMS from the WHERE clause and the
|
||||
# columns clause, but not the ORDER BY, which means the old ".c" system
|
||||
# allowed you to "order_by(s.c.foo)" to get an unnamed column in the
|
||||
# ORDER BY without adding the SELECT into the FROM and breaking the
|
||||
# query. Users will have to adjust for this use case if they were doing
|
||||
# it before.
|
||||
def _dont_test_select_from_plain_union(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = select(table).where(table.c.id == 2)
|
||||
s2 = select(table).where(table.c.id == 3)
|
||||
|
||||
u1 = union(s1, s2).alias().select()
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
@testing.requires.order_by_col_from_union
|
||||
@testing.requires.parens_in_union_contained_select_w_limit_offset
|
||||
def test_limit_offset_selectable_in_unions(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
|
||||
s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
|
||||
|
||||
u1 = union(s1, s2).limit(2)
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
@testing.requires.parens_in_union_contained_select_wo_limit_offset
|
||||
def test_order_by_selectable_in_unions(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
|
||||
s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
|
||||
|
||||
u1 = union(s1, s2).limit(2)
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
def test_distinct_selectable_in_unions(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = select(table).where(table.c.id == 2).distinct()
|
||||
s2 = select(table).where(table.c.id == 3).distinct()
|
||||
|
||||
u1 = union(s1, s2).limit(2)
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
|
||||
def test_limit_offset_aliased_selectable_in_unions(self, connection):
|
||||
table = self.tables.some_table
|
||||
s1 = (
|
||||
select(table)
|
||||
.where(table.c.id == 2)
|
||||
.limit(1)
|
||||
.order_by(table.c.id)
|
||||
.alias()
|
||||
.select()
|
||||
)
|
||||
s2 = (
|
||||
select(table)
|
||||
.where(table.c.id == 3)
|
||||
.limit(1)
|
||||
.order_by(table.c.id)
|
||||
.alias()
|
||||
.select()
|
||||
)
|
||||
|
||||
u1 = union(s1, s2).limit(2)
|
||||
with testing.expect_deprecated(
|
||||
"The SelectBase.c and SelectBase.columns "
|
||||
"attributes are deprecated"
|
||||
):
|
||||
self._assert_result(
|
||||
connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
|
||||
)
|
||||
361
lib/sqlalchemy/testing/suite/test_dialect.py
Normal file
361
lib/sqlalchemy/testing/suite/test_dialect.py
Normal file
@@ -0,0 +1,361 @@
|
||||
#! coding: utf-8
|
||||
|
||||
from . import testing
|
||||
from .. import assert_raises
|
||||
from .. import config
|
||||
from .. import engines
|
||||
from .. import eq_
|
||||
from .. import fixtures
|
||||
from .. import ne_
|
||||
from .. import provide_metadata
|
||||
from ..config import requirements
|
||||
from ..provision import set_default_schema_on_connection
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import bindparam
|
||||
from ... import event
|
||||
from ... import exc
|
||||
from ... import Integer
|
||||
from ... import literal_column
|
||||
from ... import select
|
||||
from ... import String
|
||||
from ...util import compat
|
||||
|
||||
|
||||
class ExceptionTest(fixtures.TablesTest):
|
||||
"""Test basic exception wrapping.
|
||||
|
||||
DBAPIs vary a lot in exception behavior so to actually anticipate
|
||||
specific exceptions from real round trips, we need to be conservative.
|
||||
|
||||
"""
|
||||
|
||||
run_deletes = "each"
|
||||
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"manual_pk",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
@requirements.duplicate_key_raises_integrity_error
|
||||
def test_integrity_error(self):
|
||||
|
||||
with config.db.connect() as conn:
|
||||
|
||||
trans = conn.begin()
|
||||
conn.execute(
|
||||
self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
|
||||
)
|
||||
|
||||
assert_raises(
|
||||
exc.IntegrityError,
|
||||
conn.execute,
|
||||
self.tables.manual_pk.insert(),
|
||||
{"id": 1, "data": "d1"},
|
||||
)
|
||||
|
||||
trans.rollback()
|
||||
|
||||
def test_exception_with_non_ascii(self):
|
||||
with config.db.connect() as conn:
|
||||
try:
|
||||
# try to create an error message that likely has non-ascii
|
||||
# characters in the DBAPI's message string. unfortunately
|
||||
# there's no way to make this happen with some drivers like
|
||||
# mysqlclient, pymysql. this at least does produce a non-
|
||||
# ascii error message for cx_oracle, psycopg2
|
||||
conn.execute(select(literal_column(u"méil")))
|
||||
assert False
|
||||
except exc.DBAPIError as err:
|
||||
err_str = str(err)
|
||||
|
||||
assert str(err.orig) in str(err)
|
||||
|
||||
# test that we are actually getting string on Py2k, unicode
|
||||
# on Py3k.
|
||||
if compat.py2k:
|
||||
assert isinstance(err_str, str)
|
||||
else:
|
||||
assert isinstance(err_str, str)
|
||||
|
||||
|
||||
class IsolationLevelTest(fixtures.TestBase):
|
||||
__backend__ = True
|
||||
|
||||
__requires__ = ("isolation_level",)
|
||||
|
||||
def _get_non_default_isolation_level(self):
|
||||
levels = requirements.get_isolation_levels(config)
|
||||
|
||||
default = levels["default"]
|
||||
supported = levels["supported"]
|
||||
|
||||
s = set(supported).difference(["AUTOCOMMIT", default])
|
||||
if s:
|
||||
return s.pop()
|
||||
else:
|
||||
config.skip_test("no non-default isolation level available")
|
||||
|
||||
def test_default_isolation_level(self):
|
||||
eq_(
|
||||
config.db.dialect.default_isolation_level,
|
||||
requirements.get_isolation_levels(config)["default"],
|
||||
)
|
||||
|
||||
def test_non_default_isolation_level(self):
|
||||
non_default = self._get_non_default_isolation_level()
|
||||
|
||||
with config.db.connect() as conn:
|
||||
existing = conn.get_isolation_level()
|
||||
|
||||
ne_(existing, non_default)
|
||||
|
||||
conn.execution_options(isolation_level=non_default)
|
||||
|
||||
eq_(conn.get_isolation_level(), non_default)
|
||||
|
||||
conn.dialect.reset_isolation_level(conn.connection)
|
||||
|
||||
eq_(conn.get_isolation_level(), existing)
|
||||
|
||||
def test_all_levels(self):
|
||||
levels = requirements.get_isolation_levels(config)
|
||||
|
||||
all_levels = levels["supported"]
|
||||
|
||||
for level in set(all_levels).difference(["AUTOCOMMIT"]):
|
||||
with config.db.connect() as conn:
|
||||
conn.execution_options(isolation_level=level)
|
||||
|
||||
eq_(conn.get_isolation_level(), level)
|
||||
|
||||
trans = conn.begin()
|
||||
trans.rollback()
|
||||
|
||||
eq_(conn.get_isolation_level(), level)
|
||||
|
||||
with config.db.connect() as conn:
|
||||
eq_(
|
||||
conn.get_isolation_level(),
|
||||
levels["default"],
|
||||
)
|
||||
|
||||
|
||||
class AutocommitIsolationTest(fixtures.TablesTest):
|
||||
|
||||
run_deletes = "each"
|
||||
|
||||
__requires__ = ("autocommit",)
|
||||
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"some_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
test_needs_acid=True,
|
||||
)
|
||||
|
||||
def _test_conn_autocommits(self, conn, autocommit):
|
||||
trans = conn.begin()
|
||||
conn.execute(
|
||||
self.tables.some_table.insert(), {"id": 1, "data": "some data"}
|
||||
)
|
||||
trans.rollback()
|
||||
|
||||
eq_(
|
||||
conn.scalar(select(self.tables.some_table.c.id)),
|
||||
1 if autocommit else None,
|
||||
)
|
||||
|
||||
with conn.begin():
|
||||
conn.execute(self.tables.some_table.delete())
|
||||
|
||||
def test_autocommit_on(self, connection_no_trans):
|
||||
conn = connection_no_trans
|
||||
c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
|
||||
self._test_conn_autocommits(c2, True)
|
||||
|
||||
c2.dialect.reset_isolation_level(c2.connection)
|
||||
|
||||
self._test_conn_autocommits(conn, False)
|
||||
|
||||
def test_autocommit_off(self, connection_no_trans):
|
||||
conn = connection_no_trans
|
||||
self._test_conn_autocommits(conn, False)
|
||||
|
||||
def test_turn_autocommit_off_via_default_iso_level(
|
||||
self, connection_no_trans
|
||||
):
|
||||
conn = connection_no_trans
|
||||
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
|
||||
self._test_conn_autocommits(conn, True)
|
||||
|
||||
conn.execution_options(
|
||||
isolation_level=requirements.get_isolation_levels(config)[
|
||||
"default"
|
||||
]
|
||||
)
|
||||
self._test_conn_autocommits(conn, False)
|
||||
|
||||
|
||||
class EscapingTest(fixtures.TestBase):
|
||||
@provide_metadata
|
||||
def test_percent_sign_round_trip(self):
|
||||
"""test that the DBAPI accommodates for escaped / nonescaped
|
||||
percent signs in a way that matches the compiler
|
||||
|
||||
"""
|
||||
m = self.metadata
|
||||
t = Table("t", m, Column("data", String(50)))
|
||||
t.create(config.db)
|
||||
with config.db.begin() as conn:
|
||||
conn.execute(t.insert(), dict(data="some % value"))
|
||||
conn.execute(t.insert(), dict(data="some %% other value"))
|
||||
|
||||
eq_(
|
||||
conn.scalar(
|
||||
select(t.c.data).where(
|
||||
t.c.data == literal_column("'some % value'")
|
||||
)
|
||||
),
|
||||
"some % value",
|
||||
)
|
||||
|
||||
eq_(
|
||||
conn.scalar(
|
||||
select(t.c.data).where(
|
||||
t.c.data == literal_column("'some %% other value'")
|
||||
)
|
||||
),
|
||||
"some %% other value",
|
||||
)
|
||||
|
||||
|
||||
class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
|
||||
__backend__ = True
|
||||
|
||||
__requires__ = ("default_schema_name_switch",)
|
||||
|
||||
def test_control_case(self):
|
||||
default_schema_name = config.db.dialect.default_schema_name
|
||||
|
||||
eng = engines.testing_engine()
|
||||
with eng.connect():
|
||||
pass
|
||||
|
||||
eq_(eng.dialect.default_schema_name, default_schema_name)
|
||||
|
||||
def test_wont_work_wo_insert(self):
|
||||
default_schema_name = config.db.dialect.default_schema_name
|
||||
|
||||
eng = engines.testing_engine()
|
||||
|
||||
@event.listens_for(eng, "connect")
|
||||
def on_connect(dbapi_connection, connection_record):
|
||||
set_default_schema_on_connection(
|
||||
config, dbapi_connection, config.test_schema
|
||||
)
|
||||
|
||||
with eng.connect() as conn:
|
||||
what_it_should_be = eng.dialect._get_default_schema_name(conn)
|
||||
eq_(what_it_should_be, config.test_schema)
|
||||
|
||||
eq_(eng.dialect.default_schema_name, default_schema_name)
|
||||
|
||||
def test_schema_change_on_connect(self):
|
||||
eng = engines.testing_engine()
|
||||
|
||||
@event.listens_for(eng, "connect", insert=True)
|
||||
def on_connect(dbapi_connection, connection_record):
|
||||
set_default_schema_on_connection(
|
||||
config, dbapi_connection, config.test_schema
|
||||
)
|
||||
|
||||
with eng.connect() as conn:
|
||||
what_it_should_be = eng.dialect._get_default_schema_name(conn)
|
||||
eq_(what_it_should_be, config.test_schema)
|
||||
|
||||
eq_(eng.dialect.default_schema_name, config.test_schema)
|
||||
|
||||
def test_schema_change_works_w_transactions(self):
|
||||
eng = engines.testing_engine()
|
||||
|
||||
@event.listens_for(eng, "connect", insert=True)
|
||||
def on_connect(dbapi_connection, *arg):
|
||||
set_default_schema_on_connection(
|
||||
config, dbapi_connection, config.test_schema
|
||||
)
|
||||
|
||||
with eng.connect() as conn:
|
||||
trans = conn.begin()
|
||||
what_it_should_be = eng.dialect._get_default_schema_name(conn)
|
||||
eq_(what_it_should_be, config.test_schema)
|
||||
trans.rollback()
|
||||
|
||||
what_it_should_be = eng.dialect._get_default_schema_name(conn)
|
||||
eq_(what_it_should_be, config.test_schema)
|
||||
|
||||
eq_(eng.dialect.default_schema_name, config.test_schema)
|
||||
|
||||
|
||||
class FutureWeCanSetDefaultSchemaWEventsTest(
|
||||
fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class DifficultParametersTest(fixtures.TestBase):
|
||||
__backend__ = True
|
||||
|
||||
@testing.combinations(
|
||||
("boring",),
|
||||
("per cent",),
|
||||
("per % cent",),
|
||||
("%percent",),
|
||||
("par(ens)",),
|
||||
("percent%(ens)yah",),
|
||||
("col:ons",),
|
||||
("more :: %colons%",),
|
||||
("/slashes/",),
|
||||
("more/slashes",),
|
||||
("q?marks",),
|
||||
("1param",),
|
||||
("1col:on",),
|
||||
argnames="name",
|
||||
)
|
||||
def test_round_trip(self, name, connection, metadata):
|
||||
t = Table(
|
||||
"t",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column(name, String(50), nullable=False),
|
||||
)
|
||||
|
||||
# table is created
|
||||
t.create(connection)
|
||||
|
||||
# automatic param generated by insert
|
||||
connection.execute(t.insert().values({"id": 1, name: "some name"}))
|
||||
|
||||
# automatic param generated by criteria, plus selecting the column
|
||||
stmt = select(t.c[name]).where(t.c[name] == "some name")
|
||||
|
||||
eq_(connection.scalar(stmt), "some name")
|
||||
|
||||
# use the name in a param explicitly
|
||||
stmt = select(t.c[name]).where(t.c[name] == bindparam(name))
|
||||
|
||||
row = connection.execute(stmt, {name: "some name"}).first()
|
||||
|
||||
# name works as the key from cursor.description
|
||||
eq_(row._mapping[name], "some name")
|
||||
367
lib/sqlalchemy/testing/suite/test_insert.py
Normal file
367
lib/sqlalchemy/testing/suite/test_insert.py
Normal file
@@ -0,0 +1,367 @@
|
||||
from .. import config
|
||||
from .. import engines
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..config import requirements
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import Integer
|
||||
from ... import literal
|
||||
from ... import literal_column
|
||||
from ... import select
|
||||
from ... import String
|
||||
|
||||
|
||||
class LastrowidTest(fixtures.TablesTest):
|
||||
run_deletes = "each"
|
||||
|
||||
__backend__ = True
|
||||
|
||||
__requires__ = "implements_get_lastrowid", "autoincrement_insert"
|
||||
|
||||
__engine_options__ = {"implicit_returning": False}
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"autoinc_pk",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
Table(
|
||||
"manual_pk",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
def _assert_round_trip(self, table, conn):
|
||||
row = conn.execute(table.select()).first()
|
||||
eq_(
|
||||
row,
|
||||
(
|
||||
conn.dialect.default_sequence_base,
|
||||
"some data",
|
||||
),
|
||||
)
|
||||
|
||||
def test_autoincrement_on_insert(self, connection):
|
||||
|
||||
connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
self._assert_round_trip(self.tables.autoinc_pk, connection)
|
||||
|
||||
def test_last_inserted_id(self, connection):
|
||||
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
|
||||
eq_(r.inserted_primary_key, (pk,))
|
||||
|
||||
@requirements.dbapi_lastrowid
|
||||
def test_native_lastrowid_autoinc(self, connection):
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
lastrowid = r.lastrowid
|
||||
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
|
||||
eq_(lastrowid, pk)
|
||||
|
||||
|
||||
class InsertBehaviorTest(fixtures.TablesTest):
|
||||
run_deletes = "each"
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"autoinc_pk",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
Table(
|
||||
"manual_pk",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True, autoincrement=False),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
Table(
|
||||
"includes_defaults",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
Column("x", Integer, default=5),
|
||||
Column(
|
||||
"y",
|
||||
Integer,
|
||||
default=literal_column("2", type_=Integer) + literal(2),
|
||||
),
|
||||
)
|
||||
|
||||
@requirements.autoincrement_insert
|
||||
def test_autoclose_on_insert(self):
|
||||
if requirements.returning.enabled:
|
||||
engine = engines.testing_engine(
|
||||
options={"implicit_returning": False}
|
||||
)
|
||||
else:
|
||||
engine = config.db
|
||||
|
||||
with engine.begin() as conn:
|
||||
r = conn.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
assert r._soft_closed
|
||||
assert not r.closed
|
||||
assert r.is_insert
|
||||
|
||||
# new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment
|
||||
# an insert where the PK was taken from a row that the dialect
|
||||
# selected, as is the case for mssql/pyodbc, will still report
|
||||
# returns_rows as true because there's a cursor description. in that
|
||||
# case, the row had to have been consumed at least.
|
||||
assert not r.returns_rows or r.fetchone() is None
|
||||
|
||||
@requirements.returning
|
||||
def test_autoclose_on_insert_implicit_returning(self, connection):
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
assert r._soft_closed
|
||||
assert not r.closed
|
||||
assert r.is_insert
|
||||
|
||||
# note we are experimenting with having this be True
|
||||
# as of I8091919d45421e3f53029b8660427f844fee0228 .
|
||||
# implicit returning has fetched the row, but it still is a
|
||||
# "returns rows"
|
||||
assert r.returns_rows
|
||||
|
||||
# and we should be able to fetchone() on it, we just get no row
|
||||
eq_(r.fetchone(), None)
|
||||
|
||||
# and the keys, etc.
|
||||
eq_(r.keys(), ["id"])
|
||||
|
||||
# but the dialect took in the row already. not really sure
|
||||
# what the best behavior is.
|
||||
|
||||
@requirements.empty_inserts
|
||||
def test_empty_insert(self, connection):
|
||||
r = connection.execute(self.tables.autoinc_pk.insert())
|
||||
assert r._soft_closed
|
||||
assert not r.closed
|
||||
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.select().where(
|
||||
self.tables.autoinc_pk.c.id != None
|
||||
)
|
||||
)
|
||||
eq_(len(r.all()), 1)
|
||||
|
||||
@requirements.empty_inserts_executemany
|
||||
def test_empty_insert_multiple(self, connection):
|
||||
r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}])
|
||||
assert r._soft_closed
|
||||
assert not r.closed
|
||||
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.select().where(
|
||||
self.tables.autoinc_pk.c.id != None
|
||||
)
|
||||
)
|
||||
|
||||
eq_(len(r.all()), 3)
|
||||
|
||||
@requirements.insert_from_select
|
||||
def test_insert_from_select_autoinc(self, connection):
|
||||
src_table = self.tables.manual_pk
|
||||
dest_table = self.tables.autoinc_pk
|
||||
connection.execute(
|
||||
src_table.insert(),
|
||||
[
|
||||
dict(id=1, data="data1"),
|
||||
dict(id=2, data="data2"),
|
||||
dict(id=3, data="data3"),
|
||||
],
|
||||
)
|
||||
|
||||
result = connection.execute(
|
||||
dest_table.insert().from_select(
|
||||
("data",),
|
||||
select(src_table.c.data).where(
|
||||
src_table.c.data.in_(["data2", "data3"])
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
eq_(result.inserted_primary_key, (None,))
|
||||
|
||||
result = connection.execute(
|
||||
select(dest_table.c.data).order_by(dest_table.c.data)
|
||||
)
|
||||
eq_(result.fetchall(), [("data2",), ("data3",)])
|
||||
|
||||
@requirements.insert_from_select
|
||||
def test_insert_from_select_autoinc_no_rows(self, connection):
|
||||
src_table = self.tables.manual_pk
|
||||
dest_table = self.tables.autoinc_pk
|
||||
|
||||
result = connection.execute(
|
||||
dest_table.insert().from_select(
|
||||
("data",),
|
||||
select(src_table.c.data).where(
|
||||
src_table.c.data.in_(["data2", "data3"])
|
||||
),
|
||||
)
|
||||
)
|
||||
eq_(result.inserted_primary_key, (None,))
|
||||
|
||||
result = connection.execute(
|
||||
select(dest_table.c.data).order_by(dest_table.c.data)
|
||||
)
|
||||
|
||||
eq_(result.fetchall(), [])
|
||||
|
||||
@requirements.insert_from_select
|
||||
def test_insert_from_select(self, connection):
|
||||
table = self.tables.manual_pk
|
||||
connection.execute(
|
||||
table.insert(),
|
||||
[
|
||||
dict(id=1, data="data1"),
|
||||
dict(id=2, data="data2"),
|
||||
dict(id=3, data="data3"),
|
||||
],
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
table.insert()
|
||||
.inline()
|
||||
.from_select(
|
||||
("id", "data"),
|
||||
select(table.c.id + 5, table.c.data).where(
|
||||
table.c.data.in_(["data2", "data3"])
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(table.c.data).order_by(table.c.data)
|
||||
).fetchall(),
|
||||
[("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
|
||||
)
|
||||
|
||||
@requirements.insert_from_select
|
||||
def test_insert_from_select_with_defaults(self, connection):
|
||||
table = self.tables.includes_defaults
|
||||
connection.execute(
|
||||
table.insert(),
|
||||
[
|
||||
dict(id=1, data="data1"),
|
||||
dict(id=2, data="data2"),
|
||||
dict(id=3, data="data3"),
|
||||
],
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
table.insert()
|
||||
.inline()
|
||||
.from_select(
|
||||
("id", "data"),
|
||||
select(table.c.id + 5, table.c.data).where(
|
||||
table.c.data.in_(["data2", "data3"])
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
eq_(
|
||||
connection.execute(
|
||||
select(table).order_by(table.c.data, table.c.id)
|
||||
).fetchall(),
|
||||
[
|
||||
(1, "data1", 5, 4),
|
||||
(2, "data2", 5, 4),
|
||||
(7, "data2", 5, 4),
|
||||
(3, "data3", 5, 4),
|
||||
(8, "data3", 5, 4),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ReturningTest(fixtures.TablesTest):
|
||||
run_create_tables = "each"
|
||||
__requires__ = "returning", "autoincrement_insert"
|
||||
__backend__ = True
|
||||
|
||||
__engine_options__ = {"implicit_returning": True}
|
||||
|
||||
def _assert_round_trip(self, table, conn):
|
||||
row = conn.execute(table.select()).first()
|
||||
eq_(
|
||||
row,
|
||||
(
|
||||
conn.dialect.default_sequence_base,
|
||||
"some data",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"autoinc_pk",
|
||||
metadata,
|
||||
Column(
|
||||
"id", Integer, primary_key=True, test_needs_autoincrement=True
|
||||
),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
@requirements.fetch_rows_post_commit
|
||||
def test_explicit_returning_pk_autocommit(self, connection):
|
||||
table = self.tables.autoinc_pk
|
||||
r = connection.execute(
|
||||
table.insert().returning(table.c.id), dict(data="some data")
|
||||
)
|
||||
pk = r.first()[0]
|
||||
fetched_pk = connection.scalar(select(table.c.id))
|
||||
eq_(fetched_pk, pk)
|
||||
|
||||
def test_explicit_returning_pk_no_autocommit(self, connection):
|
||||
table = self.tables.autoinc_pk
|
||||
r = connection.execute(
|
||||
table.insert().returning(table.c.id), dict(data="some data")
|
||||
)
|
||||
pk = r.first()[0]
|
||||
fetched_pk = connection.scalar(select(table.c.id))
|
||||
eq_(fetched_pk, pk)
|
||||
|
||||
def test_autoincrement_on_insert_implicit_returning(self, connection):
|
||||
|
||||
connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
self._assert_round_trip(self.tables.autoinc_pk, connection)
|
||||
|
||||
def test_last_inserted_id_implicit_returning(self, connection):
|
||||
|
||||
r = connection.execute(
|
||||
self.tables.autoinc_pk.insert(), dict(data="some data")
|
||||
)
|
||||
pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
|
||||
eq_(r.inserted_primary_key, (pk,))
|
||||
|
||||
|
||||
__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
|
||||
1738
lib/sqlalchemy/testing/suite/test_reflection.py
Normal file
1738
lib/sqlalchemy/testing/suite/test_reflection.py
Normal file
File diff suppressed because it is too large
Load Diff
426
lib/sqlalchemy/testing/suite/test_results.py
Normal file
426
lib/sqlalchemy/testing/suite/test_results.py
Normal file
@@ -0,0 +1,426 @@
|
||||
import datetime
|
||||
|
||||
from .. import engines
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..config import requirements
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import DateTime
|
||||
from ... import func
|
||||
from ... import Integer
|
||||
from ... import select
|
||||
from ... import sql
|
||||
from ... import String
|
||||
from ... import testing
|
||||
from ... import text
|
||||
from ... import util
|
||||
|
||||
|
||||
class RowFetchTest(fixtures.TablesTest):
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"plain_pk",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
Table(
|
||||
"has_dates",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("today", DateTime),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
connection.execute(
|
||||
cls.tables.plain_pk.insert(),
|
||||
[
|
||||
{"id": 1, "data": "d1"},
|
||||
{"id": 2, "data": "d2"},
|
||||
{"id": 3, "data": "d3"},
|
||||
],
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
cls.tables.has_dates.insert(),
|
||||
[{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
|
||||
)
|
||||
|
||||
def test_via_attr(self, connection):
|
||||
row = connection.execute(
|
||||
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
|
||||
).first()
|
||||
|
||||
eq_(row.id, 1)
|
||||
eq_(row.data, "d1")
|
||||
|
||||
def test_via_string(self, connection):
|
||||
row = connection.execute(
|
||||
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
|
||||
).first()
|
||||
|
||||
eq_(row._mapping["id"], 1)
|
||||
eq_(row._mapping["data"], "d1")
|
||||
|
||||
def test_via_int(self, connection):
|
||||
row = connection.execute(
|
||||
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
|
||||
).first()
|
||||
|
||||
eq_(row[0], 1)
|
||||
eq_(row[1], "d1")
|
||||
|
||||
def test_via_col_object(self, connection):
|
||||
row = connection.execute(
|
||||
self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
|
||||
).first()
|
||||
|
||||
eq_(row._mapping[self.tables.plain_pk.c.id], 1)
|
||||
eq_(row._mapping[self.tables.plain_pk.c.data], "d1")
|
||||
|
||||
@requirements.duplicate_names_in_cursor_description
|
||||
def test_row_with_dupe_names(self, connection):
|
||||
result = connection.execute(
|
||||
select(
|
||||
self.tables.plain_pk.c.data,
|
||||
self.tables.plain_pk.c.data.label("data"),
|
||||
).order_by(self.tables.plain_pk.c.id)
|
||||
)
|
||||
row = result.first()
|
||||
eq_(result.keys(), ["data", "data"])
|
||||
eq_(row, ("d1", "d1"))
|
||||
|
||||
def test_row_w_scalar_select(self, connection):
|
||||
"""test that a scalar select as a column is returned as such
|
||||
and that type conversion works OK.
|
||||
|
||||
(this is half a SQLAlchemy Core test and half to catch database
|
||||
backends that may have unusual behavior with scalar selects.)
|
||||
|
||||
"""
|
||||
datetable = self.tables.has_dates
|
||||
s = select(datetable.alias("x").c.today).scalar_subquery()
|
||||
s2 = select(datetable.c.id, s.label("somelabel"))
|
||||
row = connection.execute(s2).first()
|
||||
|
||||
eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))
|
||||
|
||||
|
||||
class PercentSchemaNamesTest(fixtures.TablesTest):
|
||||
"""tests using percent signs, spaces in table and column names.
|
||||
|
||||
This didn't work for PostgreSQL / MySQL drivers for a long time
|
||||
but is now supported.
|
||||
|
||||
"""
|
||||
|
||||
__requires__ = ("percent_schema_names",)
|
||||
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
cls.tables.percent_table = Table(
|
||||
"percent%table",
|
||||
metadata,
|
||||
Column("percent%", Integer),
|
||||
Column("spaces % more spaces", Integer),
|
||||
)
|
||||
cls.tables.lightweight_percent_table = sql.table(
|
||||
"percent%table",
|
||||
sql.column("percent%"),
|
||||
sql.column("spaces % more spaces"),
|
||||
)
|
||||
|
||||
def test_single_roundtrip(self, connection):
|
||||
percent_table = self.tables.percent_table
|
||||
for params in [
|
||||
{"percent%": 5, "spaces % more spaces": 12},
|
||||
{"percent%": 7, "spaces % more spaces": 11},
|
||||
{"percent%": 9, "spaces % more spaces": 10},
|
||||
{"percent%": 11, "spaces % more spaces": 9},
|
||||
]:
|
||||
connection.execute(percent_table.insert(), params)
|
||||
self._assert_table(connection)
|
||||
|
||||
def test_executemany_roundtrip(self, connection):
|
||||
percent_table = self.tables.percent_table
|
||||
connection.execute(
|
||||
percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
|
||||
)
|
||||
connection.execute(
|
||||
percent_table.insert(),
|
||||
[
|
||||
{"percent%": 7, "spaces % more spaces": 11},
|
||||
{"percent%": 9, "spaces % more spaces": 10},
|
||||
{"percent%": 11, "spaces % more spaces": 9},
|
||||
],
|
||||
)
|
||||
self._assert_table(connection)
|
||||
|
||||
def _assert_table(self, conn):
|
||||
percent_table = self.tables.percent_table
|
||||
lightweight_percent_table = self.tables.lightweight_percent_table
|
||||
|
||||
for table in (
|
||||
percent_table,
|
||||
percent_table.alias(),
|
||||
lightweight_percent_table,
|
||||
lightweight_percent_table.alias(),
|
||||
):
|
||||
eq_(
|
||||
list(
|
||||
conn.execute(table.select().order_by(table.c["percent%"]))
|
||||
),
|
||||
[(5, 12), (7, 11), (9, 10), (11, 9)],
|
||||
)
|
||||
|
||||
eq_(
|
||||
list(
|
||||
conn.execute(
|
||||
table.select()
|
||||
.where(table.c["spaces % more spaces"].in_([9, 10]))
|
||||
.order_by(table.c["percent%"])
|
||||
)
|
||||
),
|
||||
[(9, 10), (11, 9)],
|
||||
)
|
||||
|
||||
row = conn.execute(
|
||||
table.select().order_by(table.c["percent%"])
|
||||
).first()
|
||||
eq_(row._mapping["percent%"], 5)
|
||||
eq_(row._mapping["spaces % more spaces"], 12)
|
||||
|
||||
eq_(row._mapping[table.c["percent%"]], 5)
|
||||
eq_(row._mapping[table.c["spaces % more spaces"]], 12)
|
||||
|
||||
conn.execute(
|
||||
percent_table.update().values(
|
||||
{percent_table.c["spaces % more spaces"]: 15}
|
||||
)
|
||||
)
|
||||
|
||||
eq_(
|
||||
list(
|
||||
conn.execute(
|
||||
percent_table.select().order_by(
|
||||
percent_table.c["percent%"]
|
||||
)
|
||||
)
|
||||
),
|
||||
[(5, 15), (7, 15), (9, 15), (11, 15)],
|
||||
)
|
||||
|
||||
|
||||
class ServerSideCursorsTest(
|
||||
fixtures.TestBase, testing.AssertsExecutionResults
|
||||
):
|
||||
|
||||
__requires__ = ("server_side_cursors",)
|
||||
|
||||
__backend__ = True
|
||||
|
||||
def _is_server_side(self, cursor):
|
||||
# TODO: this is a huge issue as it prevents these tests from being
|
||||
# usable by third party dialects.
|
||||
if self.engine.dialect.driver == "psycopg2":
|
||||
return bool(cursor.name)
|
||||
elif self.engine.dialect.driver == "pymysql":
|
||||
sscursor = __import__("pymysql.cursors").cursors.SSCursor
|
||||
return isinstance(cursor, sscursor)
|
||||
elif self.engine.dialect.driver in ("aiomysql", "asyncmy"):
|
||||
return cursor.server_side
|
||||
elif self.engine.dialect.driver == "mysqldb":
|
||||
sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
|
||||
return isinstance(cursor, sscursor)
|
||||
elif self.engine.dialect.driver == "mariadbconnector":
|
||||
return not cursor.buffered
|
||||
elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
|
||||
return cursor.server_side
|
||||
elif self.engine.dialect.driver == "pg8000":
|
||||
return getattr(cursor, "server_side", False)
|
||||
else:
|
||||
return False
|
||||
|
||||
def _fixture(self, server_side_cursors):
|
||||
if server_side_cursors:
|
||||
with testing.expect_deprecated(
|
||||
"The create_engine.server_side_cursors parameter is "
|
||||
"deprecated and will be removed in a future release. "
|
||||
"Please use the Connection.execution_options.stream_results "
|
||||
"parameter."
|
||||
):
|
||||
self.engine = engines.testing_engine(
|
||||
options={"server_side_cursors": server_side_cursors}
|
||||
)
|
||||
else:
|
||||
self.engine = engines.testing_engine(
|
||||
options={"server_side_cursors": server_side_cursors}
|
||||
)
|
||||
return self.engine
|
||||
|
||||
@testing.combinations(
|
||||
("global_string", True, "select 1", True),
|
||||
("global_text", True, text("select 1"), True),
|
||||
("global_expr", True, select(1), True),
|
||||
("global_off_explicit", False, text("select 1"), False),
|
||||
(
|
||||
"stmt_option",
|
||||
False,
|
||||
select(1).execution_options(stream_results=True),
|
||||
True,
|
||||
),
|
||||
(
|
||||
"stmt_option_disabled",
|
||||
True,
|
||||
select(1).execution_options(stream_results=False),
|
||||
False,
|
||||
),
|
||||
("for_update_expr", True, select(1).with_for_update(), True),
|
||||
# TODO: need a real requirement for this, or dont use this test
|
||||
(
|
||||
"for_update_string",
|
||||
True,
|
||||
"SELECT 1 FOR UPDATE",
|
||||
True,
|
||||
testing.skip_if("sqlite"),
|
||||
),
|
||||
("text_no_ss", False, text("select 42"), False),
|
||||
(
|
||||
"text_ss_option",
|
||||
False,
|
||||
text("select 42").execution_options(stream_results=True),
|
||||
True,
|
||||
),
|
||||
id_="iaaa",
|
||||
argnames="engine_ss_arg, statement, cursor_ss_status",
|
||||
)
|
||||
def test_ss_cursor_status(
|
||||
self, engine_ss_arg, statement, cursor_ss_status
|
||||
):
|
||||
engine = self._fixture(engine_ss_arg)
|
||||
with engine.begin() as conn:
|
||||
if isinstance(statement, util.string_types):
|
||||
result = conn.exec_driver_sql(statement)
|
||||
else:
|
||||
result = conn.execute(statement)
|
||||
eq_(self._is_server_side(result.cursor), cursor_ss_status)
|
||||
result.close()
|
||||
|
||||
def test_conn_option(self):
|
||||
engine = self._fixture(False)
|
||||
|
||||
with engine.connect() as conn:
|
||||
# should be enabled for this one
|
||||
result = conn.execution_options(
|
||||
stream_results=True
|
||||
).exec_driver_sql("select 1")
|
||||
assert self._is_server_side(result.cursor)
|
||||
|
||||
def test_stmt_enabled_conn_option_disabled(self):
|
||||
engine = self._fixture(False)
|
||||
|
||||
s = select(1).execution_options(stream_results=True)
|
||||
|
||||
with engine.connect() as conn:
|
||||
# not this one
|
||||
result = conn.execution_options(stream_results=False).execute(s)
|
||||
assert not self._is_server_side(result.cursor)
|
||||
|
||||
def test_aliases_and_ss(self):
|
||||
engine = self._fixture(False)
|
||||
s1 = (
|
||||
select(sql.literal_column("1").label("x"))
|
||||
.execution_options(stream_results=True)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
# options don't propagate out when subquery is used as a FROM clause
|
||||
with engine.begin() as conn:
|
||||
result = conn.execute(s1.select())
|
||||
assert not self._is_server_side(result.cursor)
|
||||
result.close()
|
||||
|
||||
s2 = select(1).select_from(s1)
|
||||
with engine.begin() as conn:
|
||||
result = conn.execute(s2)
|
||||
assert not self._is_server_side(result.cursor)
|
||||
result.close()
|
||||
|
||||
def test_roundtrip_fetchall(self, metadata):
|
||||
md = self.metadata
|
||||
|
||||
engine = self._fixture(True)
|
||||
test_table = Table(
|
||||
"test_table",
|
||||
md,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
with engine.begin() as connection:
|
||||
test_table.create(connection, checkfirst=True)
|
||||
connection.execute(test_table.insert(), dict(data="data1"))
|
||||
connection.execute(test_table.insert(), dict(data="data2"))
|
||||
eq_(
|
||||
connection.execute(
|
||||
test_table.select().order_by(test_table.c.id)
|
||||
).fetchall(),
|
||||
[(1, "data1"), (2, "data2")],
|
||||
)
|
||||
connection.execute(
|
||||
test_table.update()
|
||||
.where(test_table.c.id == 2)
|
||||
.values(data=test_table.c.data + " updated")
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
test_table.select().order_by(test_table.c.id)
|
||||
).fetchall(),
|
||||
[(1, "data1"), (2, "data2 updated")],
|
||||
)
|
||||
connection.execute(test_table.delete())
|
||||
eq_(
|
||||
connection.scalar(
|
||||
select(func.count("*")).select_from(test_table)
|
||||
),
|
||||
0,
|
||||
)
|
||||
|
||||
def test_roundtrip_fetchmany(self, metadata):
|
||||
md = self.metadata
|
||||
|
||||
engine = self._fixture(True)
|
||||
test_table = Table(
|
||||
"test_table",
|
||||
md,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
with engine.begin() as connection:
|
||||
test_table.create(connection, checkfirst=True)
|
||||
connection.execute(
|
||||
test_table.insert(),
|
||||
[dict(data="data%d" % i) for i in range(1, 20)],
|
||||
)
|
||||
|
||||
result = connection.execute(
|
||||
test_table.select().order_by(test_table.c.id)
|
||||
)
|
||||
|
||||
eq_(
|
||||
result.fetchmany(5),
|
||||
[(i, "data%d" % i) for i in range(1, 6)],
|
||||
)
|
||||
eq_(
|
||||
result.fetchmany(10),
|
||||
[(i, "data%d" % i) for i in range(6, 16)],
|
||||
)
|
||||
eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])
|
||||
165
lib/sqlalchemy/testing/suite/test_rowcount.py
Normal file
165
lib/sqlalchemy/testing/suite/test_rowcount.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from sqlalchemy import bindparam
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import fixtures
|
||||
|
||||
|
||||
class RowCountTest(fixtures.TablesTest):
|
||||
"""test rowcount functionality"""
|
||||
|
||||
__requires__ = ("sane_rowcount",)
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"employees",
|
||||
metadata,
|
||||
Column(
|
||||
"employee_id",
|
||||
Integer,
|
||||
autoincrement=False,
|
||||
primary_key=True,
|
||||
),
|
||||
Column("name", String(50)),
|
||||
Column("department", String(1)),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
cls.data = data = [
|
||||
("Angela", "A"),
|
||||
("Andrew", "A"),
|
||||
("Anand", "A"),
|
||||
("Bob", "B"),
|
||||
("Bobette", "B"),
|
||||
("Buffy", "B"),
|
||||
("Charlie", "C"),
|
||||
("Cynthia", "C"),
|
||||
("Chris", "C"),
|
||||
]
|
||||
|
||||
employees_table = cls.tables.employees
|
||||
connection.execute(
|
||||
employees_table.insert(),
|
||||
[
|
||||
{"employee_id": i, "name": n, "department": d}
|
||||
for i, (n, d) in enumerate(data)
|
||||
],
|
||||
)
|
||||
|
||||
def test_basic(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
s = select(
|
||||
employees_table.c.name, employees_table.c.department
|
||||
).order_by(employees_table.c.employee_id)
|
||||
rows = connection.execute(s).fetchall()
|
||||
|
||||
eq_(rows, self.data)
|
||||
|
||||
def test_update_rowcount1(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
|
||||
# WHERE matches 3, 3 rows changed
|
||||
department = employees_table.c.department
|
||||
r = connection.execute(
|
||||
employees_table.update().where(department == "C"),
|
||||
{"department": "Z"},
|
||||
)
|
||||
assert r.rowcount == 3
|
||||
|
||||
def test_update_rowcount2(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
|
||||
# WHERE matches 3, 0 rows changed
|
||||
department = employees_table.c.department
|
||||
|
||||
r = connection.execute(
|
||||
employees_table.update().where(department == "C"),
|
||||
{"department": "C"},
|
||||
)
|
||||
eq_(r.rowcount, 3)
|
||||
|
||||
@testing.requires.sane_rowcount_w_returning
|
||||
def test_update_rowcount_return_defaults(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
|
||||
department = employees_table.c.department
|
||||
stmt = (
|
||||
employees_table.update()
|
||||
.where(department == "C")
|
||||
.values(name=employees_table.c.department + "Z")
|
||||
.return_defaults()
|
||||
)
|
||||
|
||||
r = connection.execute(stmt)
|
||||
eq_(r.rowcount, 3)
|
||||
|
||||
def test_raw_sql_rowcount(self, connection):
|
||||
# test issue #3622, make sure eager rowcount is called for text
|
||||
result = connection.exec_driver_sql(
|
||||
"update employees set department='Z' where department='C'"
|
||||
)
|
||||
eq_(result.rowcount, 3)
|
||||
|
||||
def test_text_rowcount(self, connection):
|
||||
# test issue #3622, make sure eager rowcount is called for text
|
||||
result = connection.execute(
|
||||
text("update employees set department='Z' " "where department='C'")
|
||||
)
|
||||
eq_(result.rowcount, 3)
|
||||
|
||||
def test_delete_rowcount(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
|
||||
# WHERE matches 3, 3 rows deleted
|
||||
department = employees_table.c.department
|
||||
r = connection.execute(
|
||||
employees_table.delete().where(department == "C")
|
||||
)
|
||||
eq_(r.rowcount, 3)
|
||||
|
||||
@testing.requires.sane_multi_rowcount
|
||||
def test_multi_update_rowcount(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
stmt = (
|
||||
employees_table.update()
|
||||
.where(employees_table.c.name == bindparam("emp_name"))
|
||||
.values(department="C")
|
||||
)
|
||||
|
||||
r = connection.execute(
|
||||
stmt,
|
||||
[
|
||||
{"emp_name": "Bob"},
|
||||
{"emp_name": "Cynthia"},
|
||||
{"emp_name": "nonexistent"},
|
||||
],
|
||||
)
|
||||
|
||||
eq_(r.rowcount, 2)
|
||||
|
||||
@testing.requires.sane_multi_rowcount
|
||||
def test_multi_delete_rowcount(self, connection):
|
||||
employees_table = self.tables.employees
|
||||
|
||||
stmt = employees_table.delete().where(
|
||||
employees_table.c.name == bindparam("emp_name")
|
||||
)
|
||||
|
||||
r = connection.execute(
|
||||
stmt,
|
||||
[
|
||||
{"emp_name": "Bob"},
|
||||
{"emp_name": "Cynthia"},
|
||||
{"emp_name": "nonexistent"},
|
||||
],
|
||||
)
|
||||
|
||||
eq_(r.rowcount, 2)
|
||||
1783
lib/sqlalchemy/testing/suite/test_select.py
Normal file
1783
lib/sqlalchemy/testing/suite/test_select.py
Normal file
File diff suppressed because it is too large
Load Diff
282
lib/sqlalchemy/testing/suite/test_sequence.py
Normal file
282
lib/sqlalchemy/testing/suite/test_sequence.py
Normal file
@@ -0,0 +1,282 @@
|
||||
from .. import config
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..assertions import is_true
|
||||
from ..config import requirements
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import inspect
|
||||
from ... import Integer
|
||||
from ... import MetaData
|
||||
from ... import Sequence
|
||||
from ... import String
|
||||
from ... import testing
|
||||
|
||||
|
||||
class SequenceTest(fixtures.TablesTest):
|
||||
__requires__ = ("sequences",)
|
||||
__backend__ = True
|
||||
|
||||
run_create_tables = "each"
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"seq_pk",
|
||||
metadata,
|
||||
Column(
|
||||
"id",
|
||||
Integer,
|
||||
Sequence("tab_id_seq"),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
Table(
|
||||
"seq_opt_pk",
|
||||
metadata,
|
||||
Column(
|
||||
"id",
|
||||
Integer,
|
||||
Sequence("tab_id_seq", data_type=Integer, optional=True),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
Table(
|
||||
"seq_no_returning",
|
||||
metadata,
|
||||
Column(
|
||||
"id",
|
||||
Integer,
|
||||
Sequence("noret_id_seq"),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("data", String(50)),
|
||||
implicit_returning=False,
|
||||
)
|
||||
|
||||
if testing.requires.schemas.enabled:
|
||||
Table(
|
||||
"seq_no_returning_sch",
|
||||
metadata,
|
||||
Column(
|
||||
"id",
|
||||
Integer,
|
||||
Sequence("noret_sch_id_seq", schema=config.test_schema),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("data", String(50)),
|
||||
implicit_returning=False,
|
||||
schema=config.test_schema,
|
||||
)
|
||||
|
||||
def test_insert_roundtrip(self, connection):
|
||||
connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
|
||||
self._assert_round_trip(self.tables.seq_pk, connection)
|
||||
|
||||
def test_insert_lastrowid(self, connection):
|
||||
r = connection.execute(
|
||||
self.tables.seq_pk.insert(), dict(data="some data")
|
||||
)
|
||||
eq_(
|
||||
r.inserted_primary_key, (testing.db.dialect.default_sequence_base,)
|
||||
)
|
||||
|
||||
def test_nextval_direct(self, connection):
|
||||
r = connection.execute(self.tables.seq_pk.c.id.default)
|
||||
eq_(r, testing.db.dialect.default_sequence_base)
|
||||
|
||||
@requirements.sequences_optional
|
||||
def test_optional_seq(self, connection):
|
||||
r = connection.execute(
|
||||
self.tables.seq_opt_pk.insert(), dict(data="some data")
|
||||
)
|
||||
eq_(r.inserted_primary_key, (1,))
|
||||
|
||||
def _assert_round_trip(self, table, conn):
|
||||
row = conn.execute(table.select()).first()
|
||||
eq_(row, (testing.db.dialect.default_sequence_base, "some data"))
|
||||
|
||||
def test_insert_roundtrip_no_implicit_returning(self, connection):
|
||||
connection.execute(
|
||||
self.tables.seq_no_returning.insert(), dict(data="some data")
|
||||
)
|
||||
self._assert_round_trip(self.tables.seq_no_returning, connection)
|
||||
|
||||
@testing.combinations((True,), (False,), argnames="implicit_returning")
|
||||
@testing.requires.schemas
|
||||
def test_insert_roundtrip_translate(self, connection, implicit_returning):
|
||||
|
||||
seq_no_returning = Table(
|
||||
"seq_no_returning_sch",
|
||||
MetaData(),
|
||||
Column(
|
||||
"id",
|
||||
Integer,
|
||||
Sequence("noret_sch_id_seq", schema="alt_schema"),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("data", String(50)),
|
||||
implicit_returning=implicit_returning,
|
||||
schema="alt_schema",
|
||||
)
|
||||
|
||||
connection = connection.execution_options(
|
||||
schema_translate_map={"alt_schema": config.test_schema}
|
||||
)
|
||||
connection.execute(seq_no_returning.insert(), dict(data="some data"))
|
||||
self._assert_round_trip(seq_no_returning, connection)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_nextval_direct_schema_translate(self, connection):
|
||||
seq = Sequence("noret_sch_id_seq", schema="alt_schema")
|
||||
connection = connection.execution_options(
|
||||
schema_translate_map={"alt_schema": config.test_schema}
|
||||
)
|
||||
|
||||
r = connection.execute(seq)
|
||||
eq_(r, testing.db.dialect.default_sequence_base)
|
||||
|
||||
|
||||
class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
|
||||
__requires__ = ("sequences",)
|
||||
__backend__ = True
|
||||
|
||||
def test_literal_binds_inline_compile(self, connection):
|
||||
table = Table(
|
||||
"x",
|
||||
MetaData(),
|
||||
Column("y", Integer, Sequence("y_seq")),
|
||||
Column("q", Integer),
|
||||
)
|
||||
|
||||
stmt = table.insert().values(q=5)
|
||||
|
||||
seq_nextval = connection.dialect.statement_compiler(
|
||||
statement=None, dialect=connection.dialect
|
||||
).visit_sequence(Sequence("y_seq"))
|
||||
self.assert_compile(
|
||||
stmt,
|
||||
"INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,),
|
||||
literal_binds=True,
|
||||
dialect=connection.dialect,
|
||||
)
|
||||
|
||||
|
||||
class HasSequenceTest(fixtures.TablesTest):
|
||||
run_deletes = None
|
||||
|
||||
__requires__ = ("sequences",)
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Sequence("user_id_seq", metadata=metadata)
|
||||
Sequence(
|
||||
"other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True
|
||||
)
|
||||
if testing.requires.schemas.enabled:
|
||||
Sequence(
|
||||
"user_id_seq", schema=config.test_schema, metadata=metadata
|
||||
)
|
||||
Sequence(
|
||||
"schema_seq", schema=config.test_schema, metadata=metadata
|
||||
)
|
||||
Table(
|
||||
"user_id_table",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
)
|
||||
|
||||
def test_has_sequence(self, connection):
|
||||
eq_(
|
||||
inspect(connection).has_sequence("user_id_seq"),
|
||||
True,
|
||||
)
|
||||
|
||||
def test_has_sequence_other_object(self, connection):
|
||||
eq_(
|
||||
inspect(connection).has_sequence("user_id_table"),
|
||||
False,
|
||||
)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_has_sequence_schema(self, connection):
|
||||
eq_(
|
||||
inspect(connection).has_sequence(
|
||||
"user_id_seq", schema=config.test_schema
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
def test_has_sequence_neg(self, connection):
|
||||
eq_(
|
||||
inspect(connection).has_sequence("some_sequence"),
|
||||
False,
|
||||
)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_has_sequence_schemas_neg(self, connection):
|
||||
eq_(
|
||||
inspect(connection).has_sequence(
|
||||
"some_sequence", schema=config.test_schema
|
||||
),
|
||||
False,
|
||||
)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_has_sequence_default_not_in_remote(self, connection):
|
||||
eq_(
|
||||
inspect(connection).has_sequence(
|
||||
"other_sequence", schema=config.test_schema
|
||||
),
|
||||
False,
|
||||
)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_has_sequence_remote_not_in_default(self, connection):
|
||||
eq_(
|
||||
inspect(connection).has_sequence("schema_seq"),
|
||||
False,
|
||||
)
|
||||
|
||||
def test_get_sequence_names(self, connection):
|
||||
exp = {"other_seq", "user_id_seq"}
|
||||
|
||||
res = set(inspect(connection).get_sequence_names())
|
||||
is_true(res.intersection(exp) == exp)
|
||||
is_true("schema_seq" not in res)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_get_sequence_names_no_sequence_schema(self, connection):
|
||||
eq_(
|
||||
inspect(connection).get_sequence_names(
|
||||
schema=config.test_schema_2
|
||||
),
|
||||
[],
|
||||
)
|
||||
|
||||
@testing.requires.schemas
|
||||
def test_get_sequence_names_sequences_schema(self, connection):
|
||||
eq_(
|
||||
sorted(
|
||||
inspect(connection).get_sequence_names(
|
||||
schema=config.test_schema
|
||||
)
|
||||
),
|
||||
["schema_seq", "user_id_seq"],
|
||||
)
|
||||
|
||||
|
||||
class HasSequenceTestEmpty(fixtures.TestBase):
|
||||
__requires__ = ("sequences",)
|
||||
__backend__ = True
|
||||
|
||||
def test_get_sequence_names_no_sequence(self, connection):
|
||||
eq_(
|
||||
inspect(connection).get_sequence_names(),
|
||||
[],
|
||||
)
|
||||
1508
lib/sqlalchemy/testing/suite/test_types.py
Normal file
1508
lib/sqlalchemy/testing/suite/test_types.py
Normal file
File diff suppressed because it is too large
Load Diff
206
lib/sqlalchemy/testing/suite/test_unicode_ddl.py
Normal file
206
lib/sqlalchemy/testing/suite/test_unicode_ddl.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# coding: utf-8
|
||||
"""verrrrry basic unicode column name testing"""
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy import util
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import fixtures
|
||||
from sqlalchemy.testing.schema import Column
|
||||
from sqlalchemy.testing.schema import Table
|
||||
from sqlalchemy.util import u
|
||||
from sqlalchemy.util import ue
|
||||
|
||||
|
||||
class UnicodeSchemaTest(fixtures.TablesTest):
|
||||
__requires__ = ("unicode_ddl",)
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
global t1, t2, t3
|
||||
|
||||
t1 = Table(
|
||||
u("unitable1"),
|
||||
metadata,
|
||||
Column(u("méil"), Integer, primary_key=True),
|
||||
Column(ue("\u6e2c\u8a66"), Integer),
|
||||
test_needs_fk=True,
|
||||
)
|
||||
t2 = Table(
|
||||
u("Unitéble2"),
|
||||
metadata,
|
||||
Column(u("méil"), Integer, primary_key=True, key="a"),
|
||||
Column(
|
||||
ue("\u6e2c\u8a66"),
|
||||
Integer,
|
||||
ForeignKey(u("unitable1.méil")),
|
||||
key="b",
|
||||
),
|
||||
test_needs_fk=True,
|
||||
)
|
||||
|
||||
# Few DBs support Unicode foreign keys
|
||||
if testing.against("sqlite"):
|
||||
t3 = Table(
|
||||
ue("\u6e2c\u8a66"),
|
||||
metadata,
|
||||
Column(
|
||||
ue("\u6e2c\u8a66_id"),
|
||||
Integer,
|
||||
primary_key=True,
|
||||
autoincrement=False,
|
||||
),
|
||||
Column(
|
||||
ue("unitable1_\u6e2c\u8a66"),
|
||||
Integer,
|
||||
ForeignKey(ue("unitable1.\u6e2c\u8a66")),
|
||||
),
|
||||
Column(
|
||||
u("Unitéble2_b"), Integer, ForeignKey(u("Unitéble2.b"))
|
||||
),
|
||||
Column(
|
||||
ue("\u6e2c\u8a66_self"),
|
||||
Integer,
|
||||
ForeignKey(ue("\u6e2c\u8a66.\u6e2c\u8a66_id")),
|
||||
),
|
||||
test_needs_fk=True,
|
||||
)
|
||||
else:
|
||||
t3 = Table(
|
||||
ue("\u6e2c\u8a66"),
|
||||
metadata,
|
||||
Column(
|
||||
ue("\u6e2c\u8a66_id"),
|
||||
Integer,
|
||||
primary_key=True,
|
||||
autoincrement=False,
|
||||
),
|
||||
Column(ue("unitable1_\u6e2c\u8a66"), Integer),
|
||||
Column(u("Unitéble2_b"), Integer),
|
||||
Column(ue("\u6e2c\u8a66_self"), Integer),
|
||||
test_needs_fk=True,
|
||||
)
|
||||
|
||||
def test_insert(self, connection):
|
||||
connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
|
||||
connection.execute(t2.insert(), {u("a"): 1, u("b"): 1})
|
||||
connection.execute(
|
||||
t3.insert(),
|
||||
{
|
||||
ue("\u6e2c\u8a66_id"): 1,
|
||||
ue("unitable1_\u6e2c\u8a66"): 5,
|
||||
u("Unitéble2_b"): 1,
|
||||
ue("\u6e2c\u8a66_self"): 1,
|
||||
},
|
||||
)
|
||||
|
||||
eq_(connection.execute(t1.select()).fetchall(), [(1, 5)])
|
||||
eq_(connection.execute(t2.select()).fetchall(), [(1, 1)])
|
||||
eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)])
|
||||
|
||||
def test_col_targeting(self, connection):
|
||||
connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
|
||||
connection.execute(t2.insert(), {u("a"): 1, u("b"): 1})
|
||||
connection.execute(
|
||||
t3.insert(),
|
||||
{
|
||||
ue("\u6e2c\u8a66_id"): 1,
|
||||
ue("unitable1_\u6e2c\u8a66"): 5,
|
||||
u("Unitéble2_b"): 1,
|
||||
ue("\u6e2c\u8a66_self"): 1,
|
||||
},
|
||||
)
|
||||
|
||||
row = connection.execute(t1.select()).first()
|
||||
eq_(row._mapping[t1.c[u("méil")]], 1)
|
||||
eq_(row._mapping[t1.c[ue("\u6e2c\u8a66")]], 5)
|
||||
|
||||
row = connection.execute(t2.select()).first()
|
||||
eq_(row._mapping[t2.c[u("a")]], 1)
|
||||
eq_(row._mapping[t2.c[u("b")]], 1)
|
||||
|
||||
row = connection.execute(t3.select()).first()
|
||||
eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_id")]], 1)
|
||||
eq_(row._mapping[t3.c[ue("unitable1_\u6e2c\u8a66")]], 5)
|
||||
eq_(row._mapping[t3.c[u("Unitéble2_b")]], 1)
|
||||
eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_self")]], 1)
|
||||
|
||||
def test_reflect(self, connection):
|
||||
connection.execute(t1.insert(), {u("méil"): 2, ue("\u6e2c\u8a66"): 7})
|
||||
connection.execute(t2.insert(), {u("a"): 2, u("b"): 2})
|
||||
connection.execute(
|
||||
t3.insert(),
|
||||
{
|
||||
ue("\u6e2c\u8a66_id"): 2,
|
||||
ue("unitable1_\u6e2c\u8a66"): 7,
|
||||
u("Unitéble2_b"): 2,
|
||||
ue("\u6e2c\u8a66_self"): 2,
|
||||
},
|
||||
)
|
||||
|
||||
meta = MetaData()
|
||||
tt1 = Table(t1.name, meta, autoload_with=connection)
|
||||
tt2 = Table(t2.name, meta, autoload_with=connection)
|
||||
tt3 = Table(t3.name, meta, autoload_with=connection)
|
||||
|
||||
connection.execute(tt1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
|
||||
connection.execute(tt2.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 1})
|
||||
connection.execute(
|
||||
tt3.insert(),
|
||||
{
|
||||
ue("\u6e2c\u8a66_id"): 1,
|
||||
ue("unitable1_\u6e2c\u8a66"): 5,
|
||||
u("Unitéble2_b"): 1,
|
||||
ue("\u6e2c\u8a66_self"): 1,
|
||||
},
|
||||
)
|
||||
|
||||
eq_(
|
||||
connection.execute(
|
||||
tt1.select().order_by(desc(u("méil")))
|
||||
).fetchall(),
|
||||
[(2, 7), (1, 5)],
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
tt2.select().order_by(desc(u("méil")))
|
||||
).fetchall(),
|
||||
[(2, 2), (1, 1)],
|
||||
)
|
||||
eq_(
|
||||
connection.execute(
|
||||
tt3.select().order_by(desc(ue("\u6e2c\u8a66_id")))
|
||||
).fetchall(),
|
||||
[(2, 7, 2, 2), (1, 5, 1, 1)],
|
||||
)
|
||||
|
||||
def test_repr(self):
|
||||
meta = MetaData()
|
||||
t = Table(
|
||||
ue("\u6e2c\u8a66"), meta, Column(ue("\u6e2c\u8a66_id"), Integer)
|
||||
)
|
||||
|
||||
if util.py2k:
|
||||
eq_(
|
||||
repr(t),
|
||||
(
|
||||
"Table('\\u6e2c\\u8a66', MetaData(), "
|
||||
"Column('\\u6e2c\\u8a66_id', Integer(), "
|
||||
"table=<\u6e2c\u8a66>), "
|
||||
"schema=None)"
|
||||
),
|
||||
)
|
||||
else:
|
||||
eq_(
|
||||
repr(t),
|
||||
(
|
||||
"Table('測試', MetaData(), "
|
||||
"Column('測試_id', Integer(), "
|
||||
"table=<測試>), "
|
||||
"schema=None)"
|
||||
),
|
||||
)
|
||||
60
lib/sqlalchemy/testing/suite/test_update_delete.py
Normal file
60
lib/sqlalchemy/testing/suite/test_update_delete.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from .. import fixtures
|
||||
from ..assertions import eq_
|
||||
from ..schema import Column
|
||||
from ..schema import Table
|
||||
from ... import Integer
|
||||
from ... import String
|
||||
|
||||
|
||||
class SimpleUpdateDeleteTest(fixtures.TablesTest):
|
||||
run_deletes = "each"
|
||||
__requires__ = ("sane_rowcount",)
|
||||
__backend__ = True
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
Table(
|
||||
"plain_pk",
|
||||
metadata,
|
||||
Column("id", Integer, primary_key=True),
|
||||
Column("data", String(50)),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_data(cls, connection):
|
||||
connection.execute(
|
||||
cls.tables.plain_pk.insert(),
|
||||
[
|
||||
{"id": 1, "data": "d1"},
|
||||
{"id": 2, "data": "d2"},
|
||||
{"id": 3, "data": "d3"},
|
||||
],
|
||||
)
|
||||
|
||||
def test_update(self, connection):
|
||||
t = self.tables.plain_pk
|
||||
r = connection.execute(
|
||||
t.update().where(t.c.id == 2), dict(data="d2_new")
|
||||
)
|
||||
assert not r.is_insert
|
||||
assert not r.returns_rows
|
||||
assert r.rowcount == 1
|
||||
|
||||
eq_(
|
||||
connection.execute(t.select().order_by(t.c.id)).fetchall(),
|
||||
[(1, "d1"), (2, "d2_new"), (3, "d3")],
|
||||
)
|
||||
|
||||
def test_delete(self, connection):
|
||||
t = self.tables.plain_pk
|
||||
r = connection.execute(t.delete().where(t.c.id == 2))
|
||||
assert not r.is_insert
|
||||
assert not r.returns_rows
|
||||
assert r.rowcount == 1
|
||||
eq_(
|
||||
connection.execute(t.select().order_by(t.c.id)).fetchall(),
|
||||
[(1, "d1"), (3, "d3")],
|
||||
)
|
||||
|
||||
|
||||
__all__ = ("SimpleUpdateDeleteTest",)
|
||||
458
lib/sqlalchemy/testing/util.py
Normal file
458
lib/sqlalchemy/testing/util.py
Normal file
@@ -0,0 +1,458 @@
|
||||
# testing/util.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
import decimal
|
||||
import gc
|
||||
import random
|
||||
import sys
|
||||
import types
|
||||
|
||||
from . import config
|
||||
from . import mock
|
||||
from .. import inspect
|
||||
from ..engine import Connection
|
||||
from ..schema import Column
|
||||
from ..schema import DropConstraint
|
||||
from ..schema import DropTable
|
||||
from ..schema import ForeignKeyConstraint
|
||||
from ..schema import MetaData
|
||||
from ..schema import Table
|
||||
from ..sql import schema
|
||||
from ..sql.sqltypes import Integer
|
||||
from ..util import decorator
|
||||
from ..util import defaultdict
|
||||
from ..util import has_refcount_gc
|
||||
from ..util import inspect_getfullargspec
|
||||
from ..util import py2k
|
||||
|
||||
|
||||
if not has_refcount_gc:
|
||||
|
||||
def non_refcount_gc_collect(*args):
|
||||
gc.collect()
|
||||
gc.collect()
|
||||
|
||||
gc_collect = lazy_gc = non_refcount_gc_collect
|
||||
else:
|
||||
# assume CPython - straight gc.collect, lazy_gc() is a pass
|
||||
gc_collect = gc.collect
|
||||
|
||||
def lazy_gc():
|
||||
pass
|
||||
|
||||
|
||||
def picklers():
|
||||
picklers = set()
|
||||
if py2k:
|
||||
try:
|
||||
import cPickle
|
||||
|
||||
picklers.add(cPickle)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import pickle
|
||||
|
||||
picklers.add(pickle)
|
||||
|
||||
# yes, this thing needs this much testing
|
||||
for pickle_ in picklers:
|
||||
for protocol in range(-2, pickle.HIGHEST_PROTOCOL):
|
||||
yield pickle_.loads, lambda d: pickle_.dumps(d, protocol)
|
||||
|
||||
|
||||
if py2k:
|
||||
|
||||
def random_choices(population, k=1):
|
||||
pop = list(population)
|
||||
# lame but works :)
|
||||
random.shuffle(pop)
|
||||
return pop[0:k]
|
||||
|
||||
|
||||
else:
|
||||
|
||||
def random_choices(population, k=1):
|
||||
return random.choices(population, k=k)
|
||||
|
||||
|
||||
def round_decimal(value, prec):
|
||||
if isinstance(value, float):
|
||||
return round(value, prec)
|
||||
|
||||
# can also use shift() here but that is 2.6 only
|
||||
return (value * decimal.Decimal("1" + "0" * prec)).to_integral(
|
||||
decimal.ROUND_FLOOR
|
||||
) / pow(10, prec)
|
||||
|
||||
|
||||
class RandomSet(set):
|
||||
def __iter__(self):
|
||||
l = list(set.__iter__(self))
|
||||
random.shuffle(l)
|
||||
return iter(l)
|
||||
|
||||
def pop(self):
|
||||
index = random.randint(0, len(self) - 1)
|
||||
item = list(set.__iter__(self))[index]
|
||||
self.remove(item)
|
||||
return item
|
||||
|
||||
def union(self, other):
|
||||
return RandomSet(set.union(self, other))
|
||||
|
||||
def difference(self, other):
|
||||
return RandomSet(set.difference(self, other))
|
||||
|
||||
def intersection(self, other):
|
||||
return RandomSet(set.intersection(self, other))
|
||||
|
||||
def copy(self):
|
||||
return RandomSet(self)
|
||||
|
||||
|
||||
def conforms_partial_ordering(tuples, sorted_elements):
|
||||
"""True if the given sorting conforms to the given partial ordering."""
|
||||
|
||||
deps = defaultdict(set)
|
||||
for parent, child in tuples:
|
||||
deps[parent].add(child)
|
||||
for i, node in enumerate(sorted_elements):
|
||||
for n in sorted_elements[i:]:
|
||||
if node in deps[n]:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def all_partial_orderings(tuples, elements):
|
||||
edges = defaultdict(set)
|
||||
for parent, child in tuples:
|
||||
edges[child].add(parent)
|
||||
|
||||
def _all_orderings(elements):
|
||||
|
||||
if len(elements) == 1:
|
||||
yield list(elements)
|
||||
else:
|
||||
for elem in elements:
|
||||
subset = set(elements).difference([elem])
|
||||
if not subset.intersection(edges[elem]):
|
||||
for sub_ordering in _all_orderings(subset):
|
||||
yield [elem] + sub_ordering
|
||||
|
||||
return iter(_all_orderings(elements))
|
||||
|
||||
|
||||
def function_named(fn, name):
|
||||
"""Return a function with a given __name__.
|
||||
|
||||
Will assign to __name__ and return the original function if possible on
|
||||
the Python implementation, otherwise a new function will be constructed.
|
||||
|
||||
This function should be phased out as much as possible
|
||||
in favor of @decorator. Tests that "generate" many named tests
|
||||
should be modernized.
|
||||
|
||||
"""
|
||||
try:
|
||||
fn.__name__ = name
|
||||
except TypeError:
|
||||
fn = types.FunctionType(
|
||||
fn.__code__, fn.__globals__, name, fn.__defaults__, fn.__closure__
|
||||
)
|
||||
return fn
|
||||
|
||||
|
||||
def run_as_contextmanager(ctx, fn, *arg, **kw):
|
||||
"""Run the given function under the given contextmanager,
|
||||
simulating the behavior of 'with' to support older
|
||||
Python versions.
|
||||
|
||||
This is not necessary anymore as we have placed 2.6
|
||||
as minimum Python version, however some tests are still using
|
||||
this structure.
|
||||
|
||||
"""
|
||||
|
||||
obj = ctx.__enter__()
|
||||
try:
|
||||
result = fn(obj, *arg, **kw)
|
||||
ctx.__exit__(None, None, None)
|
||||
return result
|
||||
except:
|
||||
exc_info = sys.exc_info()
|
||||
raise_ = ctx.__exit__(*exc_info)
|
||||
if not raise_:
|
||||
raise
|
||||
else:
|
||||
return raise_
|
||||
|
||||
|
||||
def rowset(results):
|
||||
"""Converts the results of sql execution into a plain set of column tuples.
|
||||
|
||||
Useful for asserting the results of an unordered query.
|
||||
"""
|
||||
|
||||
return {tuple(row) for row in results}
|
||||
|
||||
|
||||
def fail(msg):
|
||||
assert False, msg
|
||||
|
||||
|
||||
@decorator
|
||||
def provide_metadata(fn, *args, **kw):
|
||||
"""Provide bound MetaData for a single test, dropping afterwards.
|
||||
|
||||
Legacy; use the "metadata" pytest fixture.
|
||||
|
||||
"""
|
||||
|
||||
from . import fixtures
|
||||
|
||||
metadata = schema.MetaData()
|
||||
self = args[0]
|
||||
prev_meta = getattr(self, "metadata", None)
|
||||
self.metadata = metadata
|
||||
try:
|
||||
return fn(*args, **kw)
|
||||
finally:
|
||||
# close out some things that get in the way of dropping tables.
|
||||
# when using the "metadata" fixture, there is a set ordering
|
||||
# of things that makes sure things are cleaned up in order, however
|
||||
# the simple "decorator" nature of this legacy function means
|
||||
# we have to hardcode some of that cleanup ahead of time.
|
||||
|
||||
# close ORM sessions
|
||||
fixtures._close_all_sessions()
|
||||
|
||||
# integrate with the "connection" fixture as there are many
|
||||
# tests where it is used along with provide_metadata
|
||||
if fixtures._connection_fixture_connection:
|
||||
# TODO: this warning can be used to find all the places
|
||||
# this is used with connection fixture
|
||||
# warn("mixing legacy provide metadata with connection fixture")
|
||||
drop_all_tables_from_metadata(
|
||||
metadata, fixtures._connection_fixture_connection
|
||||
)
|
||||
# as the provide_metadata fixture is often used with "testing.db",
|
||||
# when we do the drop we have to commit the transaction so that
|
||||
# the DB is actually updated as the CREATE would have been
|
||||
# committed
|
||||
fixtures._connection_fixture_connection.get_transaction().commit()
|
||||
else:
|
||||
drop_all_tables_from_metadata(metadata, config.db)
|
||||
self.metadata = prev_meta
|
||||
|
||||
|
||||
def flag_combinations(*combinations):
    """A facade around @testing.combinations() oriented towards boolean
    keyword-based arguments.

    Each positional argument is a dict of flag-name -> bool.  A readable
    test identifier is generated by joining the names of the flags that
    are set, and the argument names are derived from the union of all
    keys seen.

    E.g.::

        @testing.flag_combinations(
            dict(lazy=False, passive=False),
            dict(lazy=True, passive=False),
            dict(lazy=False, passive=True),
            dict(lazy=False, passive=True, raiseload=True),
        )


    would result in::

        @testing.combinations(
            ('', False, False, False),
            ('lazy', True, False, False),
            ('lazy_passive', True, True, False),
            ('lazy_passive', True, True, True),
            id_='iaaa',
            argnames='lazy,passive,raiseload'
        )

    """

    # union of every flag name mentioned in any combination, in a
    # stable (sorted) order so argnames / ids are deterministic
    sorted_keys = sorted(set().union(*combinations))

    arg_tuples = []
    for flags in combinations:
        # identifier built from the flags that are switched on
        ident = "_".join(key for key in sorted_keys if flags.get(key, False))
        # missing flags default to False
        values = tuple(flags.get(key, False) for key in sorted_keys)
        arg_tuples.append((ident,) + values)

    return config.combinations(
        *arg_tuples,
        id_="i" + "a" * len(sorted_keys),
        argnames=",".join(sorted_keys)
    )
|
||||
|
||||
|
||||
def lambda_combinations(lambda_arg_sets, **kw):
    """Expand a lambda returning a sequence of argument sets into one
    combination fixture per element of that sequence.

    ``lambda_arg_sets`` is first probed with Mock placeholders to learn
    how many argument sets it produces; each resulting fixture re-invokes
    it with real keyword arguments and returns the set at its position.
    """
    spec = inspect_getfullargspec(lambda_arg_sets)

    # call once with mock placeholders just to count the arg sets
    placeholders = [mock.Mock() for _ in spec[0]]
    arg_sets = lambda_arg_sets(*placeholders)

    def _make_fixture(index):
        def fixture(**fixture_kw):
            return lambda_arg_sets(**fixture_kw)[index]

        # zero-padded name keeps fixtures distinguishable in test ids
        fixture.__name__ = "fixture_%3.3d" % index
        return fixture

    fixture_tuples = [(_make_fixture(i),) for i in range(len(arg_sets))]
    return config.combinations(*fixture_tuples, **kw)
|
||||
|
||||
|
||||
def resolve_lambda(__fn, **kw):
    """Invoke a lambda with names from **kw resolved into its namespace,
    returning the result.

    Keyword arguments matching the lambda's positional parameters are
    passed as call arguments; all remaining keywords are injected as
    globals.  This is used so that we can have module-level fixtures that
    refer to instance-level variables using lambdas.

    """
    argspec = inspect_getfullargspec(__fn)
    positional_names = argspec[0]

    # keywords matching declared parameters travel as call arguments...
    call_kw = {name: kw.pop(name) for name in positional_names}

    # ...while everything left over becomes a global of the rebound code
    namespace = dict(__fn.__globals__)
    namespace.update(kw)

    rebound = types.FunctionType(__fn.__code__, namespace)
    return rebound(**call_kw)
|
||||
|
||||
|
||||
def metadata_fixture(ddl="function"):
    """Provide MetaData for a pytest fixture.

    Returns a decorator producing a generator fixture: the decorated
    function populates a fresh MetaData, its tables are created up front,
    and everything is dropped on teardown.

    :param ddl: fixture scope passed through to ``config.fixture``,
     e.g. "function" or "class".
    """

    def decorate(fn):
        def run_ddl(self):
            self.metadata = metadata = schema.MetaData()
            try:
                # let the decorated function define tables, then emit CREATEs
                fixture_value = fn(self, metadata)
                metadata.create_all(config.db)
                # TODO:
                # somehow get a per-function dml erase fixture here
                yield fixture_value
            finally:
                metadata.drop_all(config.db)

        return config.fixture(scope=ddl)(run_ddl)

    return decorate
|
||||
|
||||
|
||||
def force_drop_names(*names):
    """Force the given table names to be dropped after the test completes,
    isolating for foreign key cycles.

    :param names: table names to drop unconditionally on the way out.
    """

    @decorator
    def wrapped(fn, *args, **kw):
        try:
            return fn(*args, **kw)
        finally:
            # drop the named tables whether the test passed or failed
            drop_all_tables(
                config.db, inspect(config.db), include_names=names
            )

    return wrapped
|
||||
|
||||
|
||||
class adict(dict):
    """A dict whose keys are also readable as attributes.

    Key lookup takes precedence over normal attribute access, so a key
    named like a dict method shadows that method.
    """

    def __getattribute__(self, key):
        # prefer item lookup; fall back to regular attribute access
        # (methods etc.) only when no such key exists
        try:
            return self[key]
        except KeyError:
            return dict.__getattribute__(self, key)

    def __call__(self, *keys):
        """Return the values for ``keys`` as a tuple, in the given order."""
        return tuple(self[key] for key in keys)

    get_all = __call__
|
||||
|
||||
|
||||
def drop_all_tables_from_metadata(metadata, engine_or_connection):
    """Drop every table in ``metadata``, accepting either an Engine or an
    already-established Connection.

    When handed an Engine, a transaction is begun for the drop; an
    existing Connection is used as-is.
    """
    from . import engines

    def _drop(conn):
        # let the connection reaper release state that would block DROPs
        engines.testing_reaper.prepare_for_drop_tables(conn)

        if conn.dialect.supports_alter:
            metadata.drop_all(conn)
        else:
            # without ALTER, FK cycles can't be sorted; tolerate the
            # "Can't sort tables" warning rather than failing
            from . import assertions

            with assertions.expect_warnings(
                "Can't sort tables", assert_=False
            ):
                metadata.drop_all(conn)

    if isinstance(engine_or_connection, Connection):
        _drop(engine_or_connection)
    else:
        with engine_or_connection.begin() as conn:
            _drop(conn)
|
||||
|
||||
|
||||
def drop_all_tables(engine, inspector, schema=None, include_names=None):
    """Drop all tables reported by ``inspector``, in reverse dependency
    order, optionally restricted to ``include_names``.

    Entries with no table name but with foreign key constraints indicate
    dependency cycles; those constraints are dropped first (requires
    ALTER support on the dialect).
    """

    if include_names is not None:
        include_names = set(include_names)

    def _wanted(name):
        # with no filter supplied, every table is eligible
        return include_names is None or name in include_names

    with engine.begin() as conn:
        sorted_entries = inspector.get_sorted_table_and_fkc_names(
            schema=schema
        )
        for table_name, fk_constraints in reversed(sorted_entries):
            if table_name:
                if not _wanted(table_name):
                    continue
                conn.execute(
                    DropTable(Table(table_name, MetaData(), schema=schema))
                )
            elif fk_constraints:
                # cyclic dependencies: drop the FK constraints themselves,
                # which is only possible when the dialect supports ALTER
                if not engine.dialect.supports_alter:
                    continue
                for fk_table_name, constraint_name in fk_constraints:
                    if not _wanted(fk_table_name):
                        continue
                    # stub Table/columns just to give DropConstraint a
                    # target; only the names matter for the emitted DDL
                    stub = Table(
                        fk_table_name,
                        MetaData(),
                        Column("x", Integer),
                        Column("y", Integer),
                        schema=schema,
                    )
                    conn.execute(
                        DropConstraint(
                            ForeignKeyConstraint(
                                [stub.c.x], [stub.c.y], name=constraint_name
                            )
                        )
                    )
|
||||
|
||||
|
||||
def teardown_events(event_cls):
    """Return a decorator ensuring ``event_cls`` listeners are cleared
    after the decorated test runs, pass or fail.
    """

    @decorator
    def _clearing_wrapper(fn, *arg, **kw):
        try:
            return fn(*arg, **kw)
        finally:
            # always remove registered listeners, even if the test raised
            event_cls._clear()

    return _clearing_wrapper
|
||||
82
lib/sqlalchemy/testing/warnings.py
Normal file
82
lib/sqlalchemy/testing/warnings.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# testing/warnings.py
|
||||
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import warnings
|
||||
|
||||
from . import assertions
|
||||
from .. import exc as sa_exc
|
||||
from ..util.langhelpers import _warnings_warn
|
||||
|
||||
|
||||
class SATestSuiteWarning(Warning):
    """Warning for a non-fatal condition detected while the test suite runs.

    Deliberately kept outside the SAWarning hierarchy so that we can work
    around tools such as Alembic doing the wrong thing with warnings.

    """
|
||||
|
||||
|
||||
def warn_test_suite(message):
    """Emit ``message`` as a :class:`SATestSuiteWarning`."""
    _warnings_warn(message, category=SATestSuiteWarning)
|
||||
|
||||
|
||||
def setup_filters():
    """Set global warning behavior for the test suite."""

    # TODO: at this point we can use the normal pytest warnings plugin,
    # if we decide the test suite can be linked to pytest only

    origin = r"^(?:test|sqlalchemy)\..*"

    # (action, keyword arguments) pairs, applied strictly in order —
    # warnings-module filters are prepended, so later entries here take
    # precedence over earlier ones
    filter_specs = [
        ("ignore", dict(category=sa_exc.SAPendingDeprecationWarning)),
        ("error", dict(category=sa_exc.SADeprecationWarning)),
        ("error", dict(category=sa_exc.SAWarning)),
        ("always", dict(category=SATestSuiteWarning)),
        ("error", dict(category=DeprecationWarning, module=origin)),
        # ignore things that are deprecated *as of* 2.0 :)
        (
            "ignore",
            dict(
                category=sa_exc.SADeprecationWarning,
                message=r".*\(deprecated since: 2.0\)$",
            ),
        ),
        (
            "ignore",
            dict(
                category=sa_exc.SADeprecationWarning,
                message=r"^The (Sybase|firebird) dialect is deprecated and will be",  # noqa: E501
            ),
        ),
    ]

    for action, kwargs in filter_specs:
        warnings.filterwarnings(action, **kwargs)

    try:
        import pytest
    except ImportError:
        pass
    else:
        warnings.filterwarnings(
            "once", category=pytest.PytestDeprecationWarning, module=origin
        )
|
||||
|
||||
|
||||
def assert_warnings(fn, warning_msgs, regex=False):
    """Assert that each of the given warnings is emitted by ``fn``.

    Deprecated.  Please use assertions.expect_warnings().

    :param fn: zero-argument callable expected to emit the warnings.
    :param warning_msgs: SAWarning message strings (or patterns, when
     ``regex`` is True) that must be observed.
    :param regex: when True, match messages as regular expressions.
    """
    with assertions._expect_warnings(
        sa_exc.SAWarning, warning_msgs, regex=regex
    ):
        return fn()
|
||||
Reference in New Issue
Block a user