summaryrefslogtreecommitdiffstats
path: root/lib/sqlalchemy/testing/fixtures.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/testing/fixtures.py')
-rw-r--r--lib/sqlalchemy/testing/fixtures.py870
1 files changed, 870 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py
new file mode 100644
index 0000000..0a2d63b
--- /dev/null
+++ b/lib/sqlalchemy/testing/fixtures.py
@@ -0,0 +1,870 @@
+# testing/fixtures.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import contextlib
+import re
+import sys
+
+import sqlalchemy as sa
+from . import assertions
+from . import config
+from . import schema
+from .entities import BasicEntity
+from .entities import ComparableEntity
+from .entities import ComparableMixin # noqa
+from .util import adict
+from .util import drop_all_tables_from_metadata
+from .. import event
+from .. import util
+from ..orm import declarative_base
+from ..orm import registry
+from ..orm.decl_api import DeclarativeMeta
+from ..schema import sort_tables_and_constraints
+
+
+@config.mark_base_test_class()
+class TestBase(object):
+ # A sequence of requirement names matching testing.requires decorators
+ __requires__ = ()
+
+ # A sequence of dialect names to exclude from the test class.
+ __unsupported_on__ = ()
+
+ # If present, test class is only runnable for the *single* specified
+ # dialect. If you need multiple, use __unsupported_on__ and invert.
+ __only_on__ = None
+
+ # A sequence of no-arg callables. If any are True, the entire testcase is
+ # skipped.
+ __skip_if__ = None
+
+ # if True, the testing reaper will not attempt to touch connection
+ # state after a test is completed and before the outer teardown
+ # starts
+ __leave_connections_for_teardown__ = False
+
+ def assert_(self, val, msg=None):
+ assert val, msg
+
+ @config.fixture()
+ def nocache(self):
+ _cache = config.db._compiled_cache
+ config.db._compiled_cache = None
+ yield
+ config.db._compiled_cache = _cache
+
+ @config.fixture()
+ def connection_no_trans(self):
+ eng = getattr(self, "bind", None) or config.db
+
+ with eng.connect() as conn:
+ yield conn
+
+ @config.fixture()
+ def connection(self):
+ global _connection_fixture_connection
+
+ eng = getattr(self, "bind", None) or config.db
+
+ conn = eng.connect()
+ trans = conn.begin()
+
+ _connection_fixture_connection = conn
+ yield conn
+
+ _connection_fixture_connection = None
+
+ if trans.is_active:
+ trans.rollback()
+ # trans would not be active here if the test is using
+ # the legacy @provide_metadata decorator still, as it will
+ # run a close all connections.
+ conn.close()
+
+ @config.fixture()
+ def close_result_when_finished(self):
+ to_close = []
+ to_consume = []
+
+ def go(result, consume=False):
+ to_close.append(result)
+ if consume:
+ to_consume.append(result)
+
+ yield go
+ for r in to_consume:
+ try:
+ r.all()
+ except:
+ pass
+ for r in to_close:
+ try:
+ r.close()
+ except:
+ pass
+
+ @config.fixture()
+ def registry(self, metadata):
+ reg = registry(metadata=metadata)
+ yield reg
+ reg.dispose()
+
+ @config.fixture
+ def decl_base(self, registry):
+ return registry.generate_base()
+
+ @config.fixture()
+ def future_connection(self, future_engine, connection):
+ # integrate the future_engine and connection fixtures so
+ # that users of the "connection" fixture will get at the
+ # "future" connection
+ yield connection
+
+ @config.fixture()
+ def future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
+
+ @config.fixture()
+ def testing_engine(self):
+ from . import engines
+
+ def gen_testing_engine(
+ url=None,
+ options=None,
+ future=None,
+ asyncio=False,
+ transfer_staticpool=False,
+ ):
+ if options is None:
+ options = {}
+ options["scope"] = "fixture"
+ return engines.testing_engine(
+ url=url,
+ options=options,
+ future=future,
+ asyncio=asyncio,
+ transfer_staticpool=transfer_staticpool,
+ )
+
+ yield gen_testing_engine
+
+ engines.testing_reaper._drop_testing_engines("fixture")
+
+ @config.fixture()
+ def async_testing_engine(self, testing_engine):
+ def go(**kw):
+ kw["asyncio"] = True
+ return testing_engine(**kw)
+
+ return go
+
+ @config.fixture
+ def fixture_session(self):
+ return fixture_session()
+
+ @config.fixture()
+ def metadata(self, request):
+ """Provide bound MetaData for a single test, dropping afterwards."""
+
+ from ..sql import schema
+
+ metadata = schema.MetaData()
+ request.instance.metadata = metadata
+ yield metadata
+ del request.instance.metadata
+
+ if (
+ _connection_fixture_connection
+ and _connection_fixture_connection.in_transaction()
+ ):
+ trans = _connection_fixture_connection.get_transaction()
+ trans.rollback()
+ with _connection_fixture_connection.begin():
+ drop_all_tables_from_metadata(
+ metadata, _connection_fixture_connection
+ )
+ else:
+ drop_all_tables_from_metadata(metadata, config.db)
+
+ @config.fixture(
+ params=[
+ (rollback, second_operation, begin_nested)
+ for rollback in (True, False)
+ for second_operation in ("none", "execute", "begin")
+ for begin_nested in (
+ True,
+ False,
+ )
+ ]
+ )
+ def trans_ctx_manager_fixture(self, request, metadata):
+ rollback, second_operation, begin_nested = request.param
+
+ from sqlalchemy import Table, Column, Integer, func, select
+ from . import eq_
+
+ t = Table("test", metadata, Column("data", Integer))
+ eng = getattr(self, "bind", None) or config.db
+
+ t.create(eng)
+
+ def run_test(subject, trans_on_subject, execute_on_subject):
+ with subject.begin() as trans:
+
+ if begin_nested:
+ if not config.requirements.savepoints.enabled:
+ config.skip_test("savepoints not enabled")
+ if execute_on_subject:
+ nested_trans = subject.begin_nested()
+ else:
+ nested_trans = trans.begin_nested()
+
+ with nested_trans:
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 10})
+ else:
+ trans.execute(t.insert(), {"data": 10})
+
+ # for nested trans, we always commit/rollback on the
+ # "nested trans" object itself.
+ # only Session(future=False) will affect savepoint
+ # transaction for session.commit/rollback
+
+ if rollback:
+ nested_trans.rollback()
+ else:
+ nested_trans.commit()
+
+ if second_operation != "none":
+ with assertions.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "Can't operate on closed transaction "
+ "inside context "
+ "manager. Please complete the context "
+ "manager "
+ "before emitting further commands.",
+ ):
+ if second_operation == "execute":
+ if execute_on_subject:
+ subject.execute(
+ t.insert(), {"data": 12}
+ )
+ else:
+ trans.execute(t.insert(), {"data": 12})
+ elif second_operation == "begin":
+ if execute_on_subject:
+ subject.begin_nested()
+ else:
+ trans.begin_nested()
+
+ # outside the nested trans block, but still inside the
+ # transaction block, we can run SQL, and it will be
+ # committed
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 14})
+ else:
+ trans.execute(t.insert(), {"data": 14})
+
+ else:
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 10})
+ else:
+ trans.execute(t.insert(), {"data": 10})
+
+ if trans_on_subject:
+ if rollback:
+ subject.rollback()
+ else:
+ subject.commit()
+ else:
+ if rollback:
+ trans.rollback()
+ else:
+ trans.commit()
+
+ if second_operation != "none":
+ with assertions.expect_raises_message(
+ sa.exc.InvalidRequestError,
+ "Can't operate on closed transaction inside "
+ "context "
+ "manager. Please complete the context manager "
+ "before emitting further commands.",
+ ):
+ if second_operation == "execute":
+ if execute_on_subject:
+ subject.execute(t.insert(), {"data": 12})
+ else:
+ trans.execute(t.insert(), {"data": 12})
+ elif second_operation == "begin":
+ if hasattr(trans, "begin"):
+ trans.begin()
+ else:
+ subject.begin()
+ elif second_operation == "begin_nested":
+ if execute_on_subject:
+ subject.begin_nested()
+ else:
+ trans.begin_nested()
+
+ expected_committed = 0
+ if begin_nested:
+ # begin_nested variant, we inserted a row after the nested
+ # block
+ expected_committed += 1
+ if not rollback:
+ # not rollback variant, our row inserted in the target
+ # block itself would be committed
+ expected_committed += 1
+
+ if execute_on_subject:
+ eq_(
+ subject.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+ else:
+ with subject.connect() as conn:
+ eq_(
+ conn.scalar(select(func.count()).select_from(t)),
+ expected_committed,
+ )
+
+ return run_test
+
+
+_connection_fixture_connection = None
+
+
+@contextlib.contextmanager
+def _push_future_engine(engine):
+
+ from ..future.engine import Engine
+ from sqlalchemy import testing
+
+ facade = Engine._future_facade(engine)
+ config._current.push_engine(facade, testing)
+
+ yield facade
+
+ config._current.pop(testing)
+
+
+class FutureEngineMixin(object):
+ @config.fixture(autouse=True, scope="class")
+ def _push_future_engine(self):
+ eng = getattr(self, "bind", None) or config.db
+ with _push_future_engine(eng):
+ yield
+
+
+class TablesTest(TestBase):
+
+ # 'once', None
+ run_setup_bind = "once"
+
+ # 'once', 'each', None
+ run_define_tables = "once"
+
+ # 'once', 'each', None
+ run_create_tables = "once"
+
+ # 'once', 'each', None
+ run_inserts = "each"
+
+ # 'each', None
+ run_deletes = "each"
+
+ # 'once', None
+ run_dispose_bind = None
+
+ bind = None
+ _tables_metadata = None
+ tables = None
+ other = None
+ sequences = None
+
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
+ cls._init_class()
+
+ cls._setup_once_tables()
+
+ cls._setup_once_inserts()
+
+ yield
+
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_inserts()
+
+ yield
+
+ self._teardown_each_tables()
+
+ @property
+ def tables_test_metadata(self):
+ return self._tables_metadata
+
+ @classmethod
+ def _init_class(cls):
+ if cls.run_define_tables == "each":
+ if cls.run_create_tables == "once":
+ cls.run_create_tables = "each"
+ assert cls.run_inserts in ("each", None)
+
+ cls.other = adict()
+ cls.tables = adict()
+ cls.sequences = adict()
+
+ cls.bind = cls.setup_bind()
+ cls._tables_metadata = sa.MetaData()
+
+ @classmethod
+ def _setup_once_inserts(cls):
+ if cls.run_inserts == "once":
+ cls._load_fixtures()
+ with cls.bind.begin() as conn:
+ cls.insert_data(conn)
+
+ @classmethod
+ def _setup_once_tables(cls):
+ if cls.run_define_tables == "once":
+ cls.define_tables(cls._tables_metadata)
+ if cls.run_create_tables == "once":
+ cls._tables_metadata.create_all(cls.bind)
+ cls.tables.update(cls._tables_metadata.tables)
+ cls.sequences.update(cls._tables_metadata._sequences)
+
+ def _setup_each_tables(self):
+ if self.run_define_tables == "each":
+ self.define_tables(self._tables_metadata)
+ if self.run_create_tables == "each":
+ self._tables_metadata.create_all(self.bind)
+ self.tables.update(self._tables_metadata.tables)
+ self.sequences.update(self._tables_metadata._sequences)
+ elif self.run_create_tables == "each":
+ self._tables_metadata.create_all(self.bind)
+
+ def _setup_each_inserts(self):
+ if self.run_inserts == "each":
+ self._load_fixtures()
+ with self.bind.begin() as conn:
+ self.insert_data(conn)
+
+ def _teardown_each_tables(self):
+ if self.run_define_tables == "each":
+ self.tables.clear()
+ if self.run_create_tables == "each":
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+ self._tables_metadata.clear()
+ elif self.run_create_tables == "each":
+ drop_all_tables_from_metadata(self._tables_metadata, self.bind)
+
+ savepoints = getattr(config.requirements, "savepoints", False)
+ if savepoints:
+ savepoints = savepoints.enabled
+
+ # no need to run deletes if tables are recreated on setup
+ if (
+ self.run_define_tables != "each"
+ and self.run_create_tables != "each"
+ and self.run_deletes == "each"
+ ):
+ with self.bind.begin() as conn:
+ for table in reversed(
+ [
+ t
+ for (t, fks) in sort_tables_and_constraints(
+ self._tables_metadata.tables.values()
+ )
+ if t is not None
+ ]
+ ):
+ try:
+ if savepoints:
+ with conn.begin_nested():
+ conn.execute(table.delete())
+ else:
+ conn.execute(table.delete())
+ except sa.exc.DBAPIError as ex:
+ util.print_(
+ ("Error emptying table %s: %r" % (table, ex)),
+ file=sys.stderr,
+ )
+
+ @classmethod
+ def _teardown_once_metadata_bind(cls):
+ if cls.run_create_tables:
+ drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
+
+ if cls.run_dispose_bind == "once":
+ cls.dispose_bind(cls.bind)
+
+ cls._tables_metadata.bind = None
+
+ if cls.run_setup_bind is not None:
+ cls.bind = None
+
+ @classmethod
+ def setup_bind(cls):
+ return config.db
+
+ @classmethod
+ def dispose_bind(cls, bind):
+ if hasattr(bind, "dispose"):
+ bind.dispose()
+ elif hasattr(bind, "close"):
+ bind.close()
+
+ @classmethod
+ def define_tables(cls, metadata):
+ pass
+
+ @classmethod
+ def fixtures(cls):
+ return {}
+
+ @classmethod
+ def insert_data(cls, connection):
+ pass
+
+ def sql_count_(self, count, fn):
+ self.assert_sql_count(self.bind, fn, count)
+
+ def sql_eq_(self, callable_, statements):
+ self.assert_sql(self.bind, callable_, statements)
+
+ @classmethod
+ def _load_fixtures(cls):
+ """Insert rows as represented by the fixtures() method."""
+ headers, rows = {}, {}
+ for table, data in cls.fixtures().items():
+ if len(data) < 2:
+ continue
+ if isinstance(table, util.string_types):
+ table = cls.tables[table]
+ headers[table] = data[0]
+ rows[table] = data[1:]
+ for table, fks in sort_tables_and_constraints(
+ cls._tables_metadata.tables.values()
+ ):
+ if table is None:
+ continue
+ if table not in headers:
+ continue
+ with cls.bind.begin() as conn:
+ conn.execute(
+ table.insert(),
+ [
+ dict(zip(headers[table], column_values))
+ for column_values in rows[table]
+ ],
+ )
+
+
+class NoCache(object):
+ @config.fixture(autouse=True, scope="function")
+ def _disable_cache(self):
+ _cache = config.db._compiled_cache
+ config.db._compiled_cache = None
+ yield
+ config.db._compiled_cache = _cache
+
+
+class RemovesEvents(object):
+ @util.memoized_property
+ def _event_fns(self):
+ return set()
+
+ def event_listen(self, target, name, fn, **kw):
+ self._event_fns.add((target, name, fn))
+ event.listen(target, name, fn, **kw)
+
+ @config.fixture(autouse=True, scope="function")
+ def _remove_events(self):
+ yield
+ for key in self._event_fns:
+ event.remove(*key)
+
+
+_fixture_sessions = set()
+
+
+def fixture_session(**kw):
+ kw.setdefault("autoflush", True)
+ kw.setdefault("expire_on_commit", True)
+
+ bind = kw.pop("bind", config.db)
+
+ sess = sa.orm.Session(bind, **kw)
+ _fixture_sessions.add(sess)
+ return sess
+
+
+def _close_all_sessions():
+ # will close all still-referenced sessions
+ sa.orm.session.close_all_sessions()
+ _fixture_sessions.clear()
+
+
+def stop_test_class_inside_fixtures(cls):
+ _close_all_sessions()
+ sa.orm.clear_mappers()
+
+
+def after_test():
+ if _fixture_sessions:
+ _close_all_sessions()
+
+
+class ORMTest(TestBase):
+ pass
+
+
+class MappedTest(TablesTest, assertions.AssertsExecutionResults):
+ # 'once', 'each', None
+ run_setup_classes = "once"
+
+ # 'once', 'each', None
+ run_setup_mappers = "each"
+
+ classes = None
+
+ @config.fixture(autouse=True, scope="class")
+ def _setup_tables_test_class(self):
+ cls = self.__class__
+ cls._init_class()
+
+ if cls.classes is None:
+ cls.classes = adict()
+
+ cls._setup_once_tables()
+ cls._setup_once_classes()
+ cls._setup_once_mappers()
+ cls._setup_once_inserts()
+
+ yield
+
+ cls._teardown_once_class()
+ cls._teardown_once_metadata_bind()
+
+ @config.fixture(autouse=True, scope="function")
+ def _setup_tables_test_instance(self):
+ self._setup_each_tables()
+ self._setup_each_classes()
+ self._setup_each_mappers()
+ self._setup_each_inserts()
+
+ yield
+
+ sa.orm.session.close_all_sessions()
+ self._teardown_each_mappers()
+ self._teardown_each_classes()
+ self._teardown_each_tables()
+
+ @classmethod
+ def _teardown_once_class(cls):
+ cls.classes.clear()
+
+ @classmethod
+ def _setup_once_classes(cls):
+ if cls.run_setup_classes == "once":
+ cls._with_register_classes(cls.setup_classes)
+
+ @classmethod
+ def _setup_once_mappers(cls):
+ if cls.run_setup_mappers == "once":
+ cls.mapper_registry, cls.mapper = cls._generate_registry()
+ cls._with_register_classes(cls.setup_mappers)
+
+ def _setup_each_mappers(self):
+ if self.run_setup_mappers != "once":
+ (
+ self.__class__.mapper_registry,
+ self.__class__.mapper,
+ ) = self._generate_registry()
+
+ if self.run_setup_mappers == "each":
+ self._with_register_classes(self.setup_mappers)
+
+ def _setup_each_classes(self):
+ if self.run_setup_classes == "each":
+ self._with_register_classes(self.setup_classes)
+
+ @classmethod
+ def _generate_registry(cls):
+ decl = registry(metadata=cls._tables_metadata)
+ return decl, decl.map_imperatively
+
+ @classmethod
+ def _with_register_classes(cls, fn):
+ """Run a setup method, framing the operation with a Base class
+ that will catch new subclasses to be established within
+ the "classes" registry.
+
+ """
+ cls_registry = cls.classes
+
+ assert cls_registry is not None
+
+ class FindFixture(type):
+ def __init__(cls, classname, bases, dict_):
+ cls_registry[classname] = cls
+ type.__init__(cls, classname, bases, dict_)
+
+ class _Base(util.with_metaclass(FindFixture, object)):
+ pass
+
+ class Basic(BasicEntity, _Base):
+ pass
+
+ class Comparable(ComparableEntity, _Base):
+ pass
+
+ cls.Basic = Basic
+ cls.Comparable = Comparable
+ fn()
+
+ def _teardown_each_mappers(self):
+ # some tests create mappers in the test bodies
+ # and will define setup_mappers as None -
+ # clear mappers in any case
+ if self.run_setup_mappers != "once":
+ sa.orm.clear_mappers()
+
+ def _teardown_each_classes(self):
+ if self.run_setup_classes != "once":
+ self.classes.clear()
+
+ @classmethod
+ def setup_classes(cls):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ pass
+
+
+class DeclarativeMappedTest(MappedTest):
+ run_setup_classes = "once"
+ run_setup_mappers = "once"
+
+ @classmethod
+ def _setup_once_tables(cls):
+ pass
+
+ @classmethod
+ def _with_register_classes(cls, fn):
+ cls_registry = cls.classes
+
+ class FindFixtureDeclarative(DeclarativeMeta):
+ def __init__(cls, classname, bases, dict_):
+ cls_registry[classname] = cls
+ DeclarativeMeta.__init__(cls, classname, bases, dict_)
+
+ class DeclarativeBasic(object):
+ __table_cls__ = schema.Table
+
+ _DeclBase = declarative_base(
+ metadata=cls._tables_metadata,
+ metaclass=FindFixtureDeclarative,
+ cls=DeclarativeBasic,
+ )
+
+ cls.DeclarativeBasic = _DeclBase
+
+ # sets up cls.Basic which is helpful for things like composite
+ # classes
+ super(DeclarativeMappedTest, cls)._with_register_classes(fn)
+
+ if cls._tables_metadata.tables and cls.run_create_tables:
+ cls._tables_metadata.create_all(config.db)
+
+
+class ComputedReflectionFixtureTest(TablesTest):
+ run_inserts = run_deletes = None
+
+ __backend__ = True
+ __requires__ = ("computed_columns", "table_reflection")
+
+ regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
+
+ def normalize(self, text):
+ return self.regexp.sub("", text).lower()
+
+ @classmethod
+ def define_tables(cls, metadata):
+ from .. import Integer
+ from .. import testing
+ from ..schema import Column
+ from ..schema import Computed
+ from ..schema import Table
+
+ Table(
+ "computed_default_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_col", Integer, Computed("normal + 42")),
+ Column("with_default", Integer, server_default="42"),
+ )
+
+ t = Table(
+ "computed_column_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_no_flag", Integer, Computed("normal + 42")),
+ )
+
+ if testing.requires.schemas.enabled:
+ t2 = Table(
+ "computed_column_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("normal", Integer),
+ Column("computed_no_flag", Integer, Computed("normal / 42")),
+ schema=config.test_schema,
+ )
+
+ if testing.requires.computed_columns_virtual.enabled:
+ t.append_column(
+ Column(
+ "computed_virtual",
+ Integer,
+ Computed("normal + 2", persisted=False),
+ )
+ )
+ if testing.requires.schemas.enabled:
+ t2.append_column(
+ Column(
+ "computed_virtual",
+ Integer,
+ Computed("normal / 2", persisted=False),
+ )
+ )
+ if testing.requires.computed_columns_stored.enabled:
+ t.append_column(
+ Column(
+ "computed_stored",
+ Integer,
+ Computed("normal - 42", persisted=True),
+ )
+ )
+ if testing.requires.schemas.enabled:
+ t2.append_column(
+ Column(
+ "computed_stored",
+ Integer,
+ Computed("normal * 42", persisted=True),
+ )
+ )