summaryrefslogtreecommitdiffstats
path: root/lib/sqlalchemy/testing/suite
diff options
context:
space:
mode:
authorxiubuzhe <xiubuzhe@sina.com>2023-10-08 20:59:00 +0800
committerxiubuzhe <xiubuzhe@sina.com>2023-10-08 20:59:00 +0800
commit1dac2263372df2b85db5d029a45721fa158a5c9d (patch)
tree0365f9c57df04178a726d7584ca6a6b955a7ce6a /lib/sqlalchemy/testing/suite
parentb494be364bb39e1de128ada7dc576a729d99907e (diff)
downloadsunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.tar.gz
sunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.tar.bz2
sunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.zip
first add files
Diffstat (limited to 'lib/sqlalchemy/testing/suite')
-rw-r--r--lib/sqlalchemy/testing/suite/__init__.py13
-rw-r--r--lib/sqlalchemy/testing/suite/test_cte.py204
-rw-r--r--lib/sqlalchemy/testing/suite/test_ddl.py381
-rw-r--r--lib/sqlalchemy/testing/suite/test_deprecations.py145
-rw-r--r--lib/sqlalchemy/testing/suite/test_dialect.py361
-rw-r--r--lib/sqlalchemy/testing/suite/test_insert.py367
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py1738
-rw-r--r--lib/sqlalchemy/testing/suite/test_results.py426
-rw-r--r--lib/sqlalchemy/testing/suite/test_rowcount.py165
-rw-r--r--lib/sqlalchemy/testing/suite/test_select.py1783
-rw-r--r--lib/sqlalchemy/testing/suite/test_sequence.py282
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py1508
-rw-r--r--lib/sqlalchemy/testing/suite/test_unicode_ddl.py206
-rw-r--r--lib/sqlalchemy/testing/suite/test_update_delete.py60
14 files changed, 7639 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py
new file mode 100644
index 0000000..30817e1
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/__init__.py
@@ -0,0 +1,13 @@
+from .test_cte import * # noqa
+from .test_ddl import * # noqa
+from .test_deprecations import * # noqa
+from .test_dialect import * # noqa
+from .test_insert import * # noqa
+from .test_reflection import * # noqa
+from .test_results import * # noqa
+from .test_rowcount import * # noqa
+from .test_select import * # noqa
+from .test_sequence import * # noqa
+from .test_types import * # noqa
+from .test_unicode_ddl import * # noqa
+from .test_update_delete import * # noqa
diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py
new file mode 100644
index 0000000..a94ee55
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_cte.py
@@ -0,0 +1,204 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import ForeignKey
+from ... import Integer
+from ... import select
+from ... import String
+from ... import testing
+
+
+class CTETest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("ctes",)
+
+ run_inserts = "each"
+ run_deletes = "each"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", ForeignKey("some_table.id")),
+ )
+
+ Table(
+ "some_other_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ Column("parent_id", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "d1", "parent_id": None},
+ {"id": 2, "data": "d2", "parent_id": 1},
+ {"id": 3, "data": "d3", "parent_id": 1},
+ {"id": 4, "data": "d4", "parent_id": 3},
+ {"id": 5, "data": "d5", "parent_id": 3},
+ ],
+ )
+
+ def test_select_nonrecursive_round_trip(self, connection):
+ some_table = self.tables.some_table
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ result = connection.execute(
+ select(cte.c.data).where(cte.c.data.in_(["d4", "d5"]))
+ )
+ eq_(result.fetchall(), [("d4",)])
+
+ def test_select_recursive_round_trip(self, connection):
+ some_table = self.tables.some_table
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte", recursive=True)
+ )
+
+ cte_alias = cte.alias("c1")
+ st1 = some_table.alias()
+ # note that SQL Server requires this to be UNION ALL,
+ # can't be UNION
+ cte = cte.union_all(
+ select(st1).where(st1.c.id == cte_alias.c.parent_id)
+ )
+ result = connection.execute(
+ select(cte.c.data)
+ .where(cte.c.data != "d2")
+ .order_by(cte.c.data.desc())
+ )
+ eq_(
+ result.fetchall(),
+ [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
+ )
+
+ def test_insert_from_select_round_trip(self, connection):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(cte)
+ )
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)],
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.update_from
+ def test_update_from_round_trip(self, connection):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(some_table)
+ )
+ )
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.update()
+ .values(parent_id=5)
+ .where(some_other_table.c.data == cte.c.data)
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None),
+ (2, "d2", 5),
+ (3, "d3", 5),
+ (4, "d4", 5),
+ (5, "d5", 3),
+ ],
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.delete_from
+ def test_delete_from_round_trip(self, connection):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(some_table)
+ )
+ )
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data == cte.c.data
+ )
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(1, "d1", None), (5, "d5", 3)],
+ )
+
+ @testing.requires.ctes_with_update_delete
+ def test_delete_scalar_subq_round_trip(self, connection):
+
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ connection.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"], select(some_table)
+ )
+ )
+
+ cte = (
+ select(some_table)
+ .where(some_table.c.data.in_(["d2", "d3", "d4"]))
+ .cte("some_cte")
+ )
+ connection.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data
+ == select(cte.c.data)
+ .where(cte.c.id == some_other_table.c.id)
+ .scalar_subquery()
+ )
+ )
+ eq_(
+ connection.execute(
+ select(some_other_table).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(1, "d1", None), (5, "d5", 3)],
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py
new file mode 100644
index 0000000..b3fee55
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_ddl.py
@@ -0,0 +1,381 @@
+import random
+
+from . import testing
+from .. import config
+from .. import fixtures
+from .. import util
+from ..assertions import eq_
+from ..assertions import is_false
+from ..assertions import is_true
+from ..config import requirements
+from ..schema import Table
+from ... import CheckConstraint
+from ... import Column
+from ... import ForeignKeyConstraint
+from ... import Index
+from ... import inspect
+from ... import Integer
+from ... import schema
+from ... import String
+from ... import UniqueConstraint
+
+
+class TableDDLTest(fixtures.TestBase):
+ __backend__ = True
+
+ def _simple_fixture(self, schema=None):
+ return Table(
+ "test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ schema=schema,
+ )
+
+ def _underscore_fixture(self):
+ return Table(
+ "_test_table",
+ self.metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("_data", String(50)),
+ )
+
+ def _table_index_fixture(self, schema=None):
+ table = self._simple_fixture(schema=schema)
+ idx = Index("test_index", table.c.data)
+ return table, idx
+
+ def _simple_roundtrip(self, table):
+ with config.db.begin() as conn:
+ conn.execute(table.insert().values((1, "some data")))
+ result = conn.execute(table.select())
+ eq_(result.first(), (1, "some data"))
+
+ @requirements.create_table
+ @util.provide_metadata
+ def test_create_table(self):
+ table = self._simple_fixture()
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.create_table
+ @requirements.schemas
+ @util.provide_metadata
+ def test_create_table_schema(self):
+ table = self._simple_fixture(schema=config.test_schema)
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.drop_table
+ @util.provide_metadata
+ def test_drop_table(self):
+ table = self._simple_fixture()
+ table.create(config.db, checkfirst=False)
+ table.drop(config.db, checkfirst=False)
+
+ @requirements.create_table
+ @util.provide_metadata
+ def test_underscore_names(self):
+ table = self._underscore_fixture()
+ table.create(config.db, checkfirst=False)
+ self._simple_roundtrip(table)
+
+ @requirements.comment_reflection
+ @util.provide_metadata
+ def test_add_table_comment(self, connection):
+ table = self._simple_fixture()
+ table.create(connection, checkfirst=False)
+ table.comment = "a comment"
+ connection.execute(schema.SetTableComment(table))
+ eq_(
+ inspect(connection).get_table_comment("test_table"),
+ {"text": "a comment"},
+ )
+
+ @requirements.comment_reflection
+ @util.provide_metadata
+ def test_drop_table_comment(self, connection):
+ table = self._simple_fixture()
+ table.create(connection, checkfirst=False)
+ table.comment = "a comment"
+ connection.execute(schema.SetTableComment(table))
+ connection.execute(schema.DropTableComment(table))
+ eq_(
+ inspect(connection).get_table_comment("test_table"), {"text": None}
+ )
+
+ @requirements.table_ddl_if_exists
+ @util.provide_metadata
+ def test_create_table_if_not_exists(self, connection):
+ table = self._simple_fixture()
+
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+
+ is_true(inspect(connection).has_table("test_table"))
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+
+ @requirements.index_ddl_if_exists
+ @util.provide_metadata
+ def test_create_index_if_not_exists(self, connection):
+ table, idx = self._table_index_fixture()
+
+ connection.execute(schema.CreateTable(table, if_not_exists=True))
+ is_true(inspect(connection).has_table("test_table"))
+ is_false(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.CreateIndex(idx, if_not_exists=True))
+
+ is_true(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.CreateIndex(idx, if_not_exists=True))
+
+ @requirements.table_ddl_if_exists
+ @util.provide_metadata
+ def test_drop_table_if_exists(self, connection):
+ table = self._simple_fixture()
+
+ table.create(connection)
+
+ is_true(inspect(connection).has_table("test_table"))
+
+ connection.execute(schema.DropTable(table, if_exists=True))
+
+ is_false(inspect(connection).has_table("test_table"))
+
+ connection.execute(schema.DropTable(table, if_exists=True))
+
+ @requirements.index_ddl_if_exists
+ @util.provide_metadata
+ def test_drop_index_if_exists(self, connection):
+ table, idx = self._table_index_fixture()
+
+ table.create(connection)
+
+ is_true(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.DropIndex(idx, if_exists=True))
+
+ is_false(
+ "test_index"
+ in [
+ ix["name"]
+ for ix in inspect(connection).get_indexes("test_table")
+ ]
+ )
+
+ connection.execute(schema.DropIndex(idx, if_exists=True))
+
+
+class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest):
+ pass
+
+
+class LongNameBlowoutTest(fixtures.TestBase):
+ """test the creation of a variety of DDL structures and ensure
+ label length limits pass on backends
+
+ """
+
+ __backend__ = True
+
+ def fk(self, metadata, connection):
+ convention = {
+ "fk": "foreign_key_%(table_name)s_"
+ "%(column_0_N_name)s_"
+ "%(referred_table_name)s_"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(20))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ test_needs_fk=True,
+ )
+
+ cons = ForeignKeyConstraint(
+ ["aid"], ["a_things_with_stuff.id_long_column_name"]
+ )
+ Table(
+ "b_related_things_of_value",
+ metadata,
+ Column(
+ "aid",
+ ),
+ cons,
+ test_needs_fk=True,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+
+ if testing.requires.foreign_key_constraint_name_reflection.enabled:
+ insp = inspect(connection)
+ fks = insp.get_foreign_keys("b_related_things_of_value")
+ reflected_name = fks[0]["name"]
+
+ return actual_name, reflected_name
+ else:
+ return actual_name, None
+
+ def pk(self, metadata, connection):
+ convention = {
+ "pk": "primary_key_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ a = Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer, primary_key=True),
+ )
+ cons = a.primary_key
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ pk = insp.get_pk_constraint("a_things_with_stuff")
+ reflected_name = pk["name"]
+ return actual_name, reflected_name
+
+ def ix(self, metadata, connection):
+ convention = {
+ "ix": "index_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ a = Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer),
+ )
+ cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name)
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ ix = insp.get_indexes("a_things_with_stuff")
+ reflected_name = ix[0]["name"]
+ return actual_name, reflected_name
+
+ def uq(self, metadata, connection):
+ convention = {
+ "uq": "unique_constraint_%(table_name)s_"
+ "%(column_0_N_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ cons = UniqueConstraint("id_long_column_name", "id_another_long_name")
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("id_another_long_name", Integer),
+ cons,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ uq = insp.get_unique_constraints("a_things_with_stuff")
+ reflected_name = uq[0]["name"]
+ return actual_name, reflected_name
+
+ def ck(self, metadata, connection):
+ convention = {
+ "ck": "check_constraint_%(table_name)s"
+ + (
+ "_".join(
+ "".join(random.choice("abcdef") for j in range(30))
+ for i in range(10)
+ )
+ ),
+ }
+ metadata.naming_convention = convention
+
+ cons = CheckConstraint("some_long_column_name > 5")
+ Table(
+ "a_things_with_stuff",
+ metadata,
+ Column("id_long_column_name", Integer, primary_key=True),
+ Column("some_long_column_name", Integer),
+ cons,
+ )
+ actual_name = cons.name
+
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ ck = insp.get_check_constraints("a_things_with_stuff")
+ reflected_name = ck[0]["name"]
+ return actual_name, reflected_name
+
+ @testing.combinations(
+ ("fk",),
+ ("pk",),
+ ("ix",),
+ ("ck", testing.requires.check_constraint_reflection.as_skips()),
+ ("uq", testing.requires.unique_constraint_reflection.as_skips()),
+ argnames="type_",
+ )
+ def test_long_convention_name(self, type_, metadata, connection):
+ actual_name, reflected_name = getattr(self, type_)(
+ metadata, connection
+ )
+
+ assert len(actual_name) > 255
+
+ if reflected_name is not None:
+ overlap = actual_name[0 : len(reflected_name)]
+ if len(overlap) < len(actual_name):
+ eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5])
+ else:
+ eq_(overlap, reflected_name)
+
+
+__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest")
diff --git a/lib/sqlalchemy/testing/suite/test_deprecations.py b/lib/sqlalchemy/testing/suite/test_deprecations.py
new file mode 100644
index 0000000..b36162f
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_deprecations.py
@@ -0,0 +1,145 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import select
+from ... import testing
+from ... import union
+
+
+class DeprecatedCompoundSelectTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ ],
+ )
+
+ def _assert_result(self, conn, select, result, params=()):
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ def test_plain_union(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ # note we've had to remove one use case entirely, which is this
+ # one. the Select gets its FROMS from the WHERE clause and the
+ # columns clause, but not the ORDER BY, which means the old ".c" system
+ # allowed you to "order_by(s.c.foo)" to get an unnamed column in the
+ # ORDER BY without adding the SELECT into the FROM and breaking the
+ # query. Users will have to adjust for this use case if they were doing
+ # it before.
+ def _dont_test_select_from_plain_union(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2).alias().select()
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.order_by_col_from_union
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_wo_limit_offset
+ def test_order_by_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_distinct_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).distinct()
+ s2 = select(table).where(table.c.id == 3).distinct()
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_limit_offset_aliased_selectable_in_unions(self, connection):
+ table = self.tables.some_table
+ s1 = (
+ select(table)
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+ s2 = (
+ select(table)
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+
+ u1 = union(s1, s2).limit(2)
+ with testing.expect_deprecated(
+ "The SelectBase.c and SelectBase.columns "
+ "attributes are deprecated"
+ ):
+ self._assert_result(
+ connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py
new file mode 100644
index 0000000..c2c17d0
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_dialect.py
@@ -0,0 +1,361 @@
+#! coding: utf-8
+
+from . import testing
+from .. import assert_raises
+from .. import config
+from .. import engines
+from .. import eq_
+from .. import fixtures
+from .. import ne_
+from .. import provide_metadata
+from ..config import requirements
+from ..provision import set_default_schema_on_connection
+from ..schema import Column
+from ..schema import Table
+from ... import bindparam
+from ... import event
+from ... import exc
+from ... import Integer
+from ... import literal_column
+from ... import select
+from ... import String
+from ...util import compat
+
+
+class ExceptionTest(fixtures.TablesTest):
+ """Test basic exception wrapping.
+
+ DBAPIs vary a lot in exception behavior so to actually anticipate
+ specific exceptions from real round trips, we need to be conservative.
+
+ """
+
+ run_deletes = "each"
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+
+ @requirements.duplicate_key_raises_integrity_error
+ def test_integrity_error(self):
+
+ with config.db.connect() as conn:
+
+ trans = conn.begin()
+ conn.execute(
+ self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
+ )
+
+ assert_raises(
+ exc.IntegrityError,
+ conn.execute,
+ self.tables.manual_pk.insert(),
+ {"id": 1, "data": "d1"},
+ )
+
+ trans.rollback()
+
+ def test_exception_with_non_ascii(self):
+ with config.db.connect() as conn:
+ try:
+ # try to create an error message that likely has non-ascii
+ # characters in the DBAPI's message string. unfortunately
+ # there's no way to make this happen with some drivers like
+ # mysqlclient, pymysql. this at least does produce a non-
+ # ascii error message for cx_oracle, psycopg2
+ conn.execute(select(literal_column(u"méil")))
+ assert False
+ except exc.DBAPIError as err:
+ err_str = str(err)
+
+ assert str(err.orig) in str(err)
+
+ # test that we are actually getting string on Py2k, unicode
+ # on Py3k.
+ if compat.py2k:
+ assert isinstance(err_str, str)
+ else:
+ assert isinstance(err_str, str)
+
+
+class IsolationLevelTest(fixtures.TestBase):
+ __backend__ = True
+
+ __requires__ = ("isolation_level",)
+
+ def _get_non_default_isolation_level(self):
+ levels = requirements.get_isolation_levels(config)
+
+ default = levels["default"]
+ supported = levels["supported"]
+
+ s = set(supported).difference(["AUTOCOMMIT", default])
+ if s:
+ return s.pop()
+ else:
+ config.skip_test("no non-default isolation level available")
+
+ def test_default_isolation_level(self):
+ eq_(
+ config.db.dialect.default_isolation_level,
+ requirements.get_isolation_levels(config)["default"],
+ )
+
+ def test_non_default_isolation_level(self):
+ non_default = self._get_non_default_isolation_level()
+
+ with config.db.connect() as conn:
+ existing = conn.get_isolation_level()
+
+ ne_(existing, non_default)
+
+ conn.execution_options(isolation_level=non_default)
+
+ eq_(conn.get_isolation_level(), non_default)
+
+ conn.dialect.reset_isolation_level(conn.connection)
+
+ eq_(conn.get_isolation_level(), existing)
+
+ def test_all_levels(self):
+ levels = requirements.get_isolation_levels(config)
+
+ all_levels = levels["supported"]
+
+ for level in set(all_levels).difference(["AUTOCOMMIT"]):
+ with config.db.connect() as conn:
+ conn.execution_options(isolation_level=level)
+
+ eq_(conn.get_isolation_level(), level)
+
+ trans = conn.begin()
+ trans.rollback()
+
+ eq_(conn.get_isolation_level(), level)
+
+ with config.db.connect() as conn:
+ eq_(
+ conn.get_isolation_level(),
+ levels["default"],
+ )
+
+
+class AutocommitIsolationTest(fixtures.TablesTest):
+
+ run_deletes = "each"
+
+ __requires__ = ("autocommit",)
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ test_needs_acid=True,
+ )
+
+ def _test_conn_autocommits(self, conn, autocommit):
+ trans = conn.begin()
+ conn.execute(
+ self.tables.some_table.insert(), {"id": 1, "data": "some data"}
+ )
+ trans.rollback()
+
+ eq_(
+ conn.scalar(select(self.tables.some_table.c.id)),
+ 1 if autocommit else None,
+ )
+
+ with conn.begin():
+ conn.execute(self.tables.some_table.delete())
+
+ def test_autocommit_on(self, connection_no_trans):
+ conn = connection_no_trans
+ c2 = conn.execution_options(isolation_level="AUTOCOMMIT")
+ self._test_conn_autocommits(c2, True)
+
+ c2.dialect.reset_isolation_level(c2.connection)
+
+ self._test_conn_autocommits(conn, False)
+
+ def test_autocommit_off(self, connection_no_trans):
+ conn = connection_no_trans
+ self._test_conn_autocommits(conn, False)
+
+ def test_turn_autocommit_off_via_default_iso_level(
+ self, connection_no_trans
+ ):
+ conn = connection_no_trans
+ conn = conn.execution_options(isolation_level="AUTOCOMMIT")
+ self._test_conn_autocommits(conn, True)
+
+ conn.execution_options(
+ isolation_level=requirements.get_isolation_levels(config)[
+ "default"
+ ]
+ )
+ self._test_conn_autocommits(conn, False)
+
+
+class EscapingTest(fixtures.TestBase):
+ @provide_metadata
+ def test_percent_sign_round_trip(self):
+ """test that the DBAPI accommodates for escaped / nonescaped
+ percent signs in a way that matches the compiler
+
+ """
+ m = self.metadata
+ t = Table("t", m, Column("data", String(50)))
+ t.create(config.db)
+ with config.db.begin() as conn:
+ conn.execute(t.insert(), dict(data="some % value"))
+ conn.execute(t.insert(), dict(data="some %% other value"))
+
+ eq_(
+ conn.scalar(
+ select(t.c.data).where(
+ t.c.data == literal_column("'some % value'")
+ )
+ ),
+ "some % value",
+ )
+
+ eq_(
+ conn.scalar(
+ select(t.c.data).where(
+ t.c.data == literal_column("'some %% other value'")
+ )
+ ),
+ "some %% other value",
+ )
+
+
+class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase):
+ __backend__ = True
+
+ __requires__ = ("default_schema_name_switch",)
+
+ def test_control_case(self):
+ default_schema_name = config.db.dialect.default_schema_name
+
+ eng = engines.testing_engine()
+ with eng.connect():
+ pass
+
+ eq_(eng.dialect.default_schema_name, default_schema_name)
+
+ def test_wont_work_wo_insert(self):
+ default_schema_name = config.db.dialect.default_schema_name
+
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect")
+ def on_connect(dbapi_connection, connection_record):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, default_schema_name)
+
+ def test_schema_change_on_connect(self):
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect", insert=True)
+ def on_connect(dbapi_connection, connection_record):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, config.test_schema)
+
+ def test_schema_change_works_w_transactions(self):
+ eng = engines.testing_engine()
+
+ @event.listens_for(eng, "connect", insert=True)
+ def on_connect(dbapi_connection, *arg):
+ set_default_schema_on_connection(
+ config, dbapi_connection, config.test_schema
+ )
+
+ with eng.connect() as conn:
+ trans = conn.begin()
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+ trans.rollback()
+
+ what_it_should_be = eng.dialect._get_default_schema_name(conn)
+ eq_(what_it_should_be, config.test_schema)
+
+ eq_(eng.dialect.default_schema_name, config.test_schema)
+
+
+class FutureWeCanSetDefaultSchemaWEventsTest(
+ fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest
+):
+ pass
+
+
+class DifficultParametersTest(fixtures.TestBase):
+ __backend__ = True
+
+ @testing.combinations(
+ ("boring",),
+ ("per cent",),
+ ("per % cent",),
+ ("%percent",),
+ ("par(ens)",),
+ ("percent%(ens)yah",),
+ ("col:ons",),
+ ("more :: %colons%",),
+ ("/slashes/",),
+ ("more/slashes",),
+ ("q?marks",),
+ ("1param",),
+ ("1col:on",),
+ argnames="name",
+ )
+ def test_round_trip(self, name, connection, metadata):
+ t = Table(
+ "t",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column(name, String(50), nullable=False),
+ )
+
+ # table is created
+ t.create(connection)
+
+ # automatic param generated by insert
+ connection.execute(t.insert().values({"id": 1, name: "some name"}))
+
+ # automatic param generated by criteria, plus selecting the column
+ stmt = select(t.c[name]).where(t.c[name] == "some name")
+
+ eq_(connection.scalar(stmt), "some name")
+
+ # use the name in a param explicitly
+ stmt = select(t.c[name]).where(t.c[name] == bindparam(name))
+
+ row = connection.execute(stmt, {name: "some name"}).first()
+
+ # name works as the key from cursor.description
+ eq_(row._mapping[name], "some name")
diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py
new file mode 100644
index 0000000..3c22f50
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_insert.py
@@ -0,0 +1,367 @@
+from .. import config
+from .. import engines
+from .. import fixtures
+from ..assertions import eq_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import literal
+from ... import literal_column
+from ... import select
+from ... import String
+
+
+class LastrowidTest(fixtures.TablesTest):
+ run_deletes = "each"
+
+ __backend__ = True
+
+ __requires__ = "implements_get_lastrowid", "autoincrement_insert"
+
+ __engine_options__ = {"implicit_returning": False}
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+
+ def _assert_round_trip(self, table, conn):
+ row = conn.execute(table.select()).first()
+ eq_(
+ row,
+ (
+ conn.dialect.default_sequence_base,
+ "some data",
+ ),
+ )
+
+ def test_autoincrement_on_insert(self, connection):
+
+ connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.autoinc_pk, connection)
+
+ def test_last_inserted_id(self, connection):
+
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(r.inserted_primary_key, (pk,))
+
+ @requirements.dbapi_lastrowid
+ def test_native_lastrowid_autoinc(self, connection):
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ lastrowid = r.lastrowid
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(lastrowid, pk)
+
+
+class InsertBehaviorTest(fixtures.TablesTest):
+ run_deletes = "each"
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+ Table(
+ "manual_pk",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("data", String(50)),
+ )
+ Table(
+ "includes_defaults",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ Column("x", Integer, default=5),
+ Column(
+ "y",
+ Integer,
+ default=literal_column("2", type_=Integer) + literal(2),
+ ),
+ )
+
+ @requirements.autoincrement_insert
+ def test_autoclose_on_insert(self):
+ if requirements.returning.enabled:
+ engine = engines.testing_engine(
+ options={"implicit_returning": False}
+ )
+ else:
+ engine = config.db
+
+ with engine.begin() as conn:
+ r = conn.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ assert r._soft_closed
+ assert not r.closed
+ assert r.is_insert
+
+ # new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment
+ # an insert where the PK was taken from a row that the dialect
+ # selected, as is the case for mssql/pyodbc, will still report
+ # returns_rows as true because there's a cursor description. in that
+ # case, the row had to have been consumed at least.
+ assert not r.returns_rows or r.fetchone() is None
+
+ @requirements.returning
+ def test_autoclose_on_insert_implicit_returning(self, connection):
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ assert r._soft_closed
+ assert not r.closed
+ assert r.is_insert
+
+ # note we are experimenting with having this be True
+ # as of I8091919d45421e3f53029b8660427f844fee0228 .
+ # implicit returning has fetched the row, but it still is a
+ # "returns rows"
+ assert r.returns_rows
+
+ # and we should be able to fetchone() on it, we just get no row
+ eq_(r.fetchone(), None)
+
+ # and the keys, etc.
+ eq_(r.keys(), ["id"])
+
+ # but the dialect took in the row already. not really sure
+ # what the best behavior is.
+
+ @requirements.empty_inserts
+ def test_empty_insert(self, connection):
+ r = connection.execute(self.tables.autoinc_pk.insert())
+ assert r._soft_closed
+ assert not r.closed
+
+ r = connection.execute(
+ self.tables.autoinc_pk.select().where(
+ self.tables.autoinc_pk.c.id != None
+ )
+ )
+ eq_(len(r.all()), 1)
+
+ @requirements.empty_inserts_executemany
+ def test_empty_insert_multiple(self, connection):
+ r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}])
+ assert r._soft_closed
+ assert not r.closed
+
+ r = connection.execute(
+ self.tables.autoinc_pk.select().where(
+ self.tables.autoinc_pk.c.id != None
+ )
+ )
+
+ eq_(len(r.all()), 3)
+
+ @requirements.insert_from_select
+ def test_insert_from_select_autoinc(self, connection):
+ src_table = self.tables.manual_pk
+ dest_table = self.tables.autoinc_pk
+ connection.execute(
+ src_table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ result = connection.execute(
+ dest_table.insert().from_select(
+ ("data",),
+ select(src_table.c.data).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
+ eq_(result.inserted_primary_key, (None,))
+
+ result = connection.execute(
+ select(dest_table.c.data).order_by(dest_table.c.data)
+ )
+ eq_(result.fetchall(), [("data2",), ("data3",)])
+
+ @requirements.insert_from_select
+ def test_insert_from_select_autoinc_no_rows(self, connection):
+ src_table = self.tables.manual_pk
+ dest_table = self.tables.autoinc_pk
+
+ result = connection.execute(
+ dest_table.insert().from_select(
+ ("data",),
+ select(src_table.c.data).where(
+ src_table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+ eq_(result.inserted_primary_key, (None,))
+
+ result = connection.execute(
+ select(dest_table.c.data).order_by(dest_table.c.data)
+ )
+
+ eq_(result.fetchall(), [])
+
+ @requirements.insert_from_select
+ def test_insert_from_select(self, connection):
+ table = self.tables.manual_pk
+ connection.execute(
+ table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ connection.execute(
+ table.insert()
+ .inline()
+ .from_select(
+ ("id", "data"),
+ select(table.c.id + 5, table.c.data).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
+ eq_(
+ connection.execute(
+ select(table.c.data).order_by(table.c.data)
+ ).fetchall(),
+ [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)],
+ )
+
+ @requirements.insert_from_select
+ def test_insert_from_select_with_defaults(self, connection):
+ table = self.tables.includes_defaults
+ connection.execute(
+ table.insert(),
+ [
+ dict(id=1, data="data1"),
+ dict(id=2, data="data2"),
+ dict(id=3, data="data3"),
+ ],
+ )
+
+ connection.execute(
+ table.insert()
+ .inline()
+ .from_select(
+ ("id", "data"),
+ select(table.c.id + 5, table.c.data).where(
+ table.c.data.in_(["data2", "data3"])
+ ),
+ )
+ )
+
+ eq_(
+ connection.execute(
+ select(table).order_by(table.c.data, table.c.id)
+ ).fetchall(),
+ [
+ (1, "data1", 5, 4),
+ (2, "data2", 5, 4),
+ (7, "data2", 5, 4),
+ (3, "data3", 5, 4),
+ (8, "data3", 5, 4),
+ ],
+ )
+
+
+class ReturningTest(fixtures.TablesTest):
+ run_create_tables = "each"
+ __requires__ = "returning", "autoincrement_insert"
+ __backend__ = True
+
+ __engine_options__ = {"implicit_returning": True}
+
+ def _assert_round_trip(self, table, conn):
+ row = conn.execute(table.select()).first()
+ eq_(
+ row,
+ (
+ conn.dialect.default_sequence_base,
+ "some data",
+ ),
+ )
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "autoinc_pk",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("data", String(50)),
+ )
+
+ @requirements.fetch_rows_post_commit
+ def test_explicit_returning_pk_autocommit(self, connection):
+ table = self.tables.autoinc_pk
+ r = connection.execute(
+ table.insert().returning(table.c.id), dict(data="some data")
+ )
+ pk = r.first()[0]
+ fetched_pk = connection.scalar(select(table.c.id))
+ eq_(fetched_pk, pk)
+
+ def test_explicit_returning_pk_no_autocommit(self, connection):
+ table = self.tables.autoinc_pk
+ r = connection.execute(
+ table.insert().returning(table.c.id), dict(data="some data")
+ )
+ pk = r.first()[0]
+ fetched_pk = connection.scalar(select(table.c.id))
+ eq_(fetched_pk, pk)
+
+ def test_autoincrement_on_insert_implicit_returning(self, connection):
+
+ connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.autoinc_pk, connection)
+
+ def test_last_inserted_id_implicit_returning(self, connection):
+
+ r = connection.execute(
+ self.tables.autoinc_pk.insert(), dict(data="some data")
+ )
+ pk = connection.scalar(select(self.tables.autoinc_pk.c.id))
+ eq_(r.inserted_primary_key, (pk,))
+
+
+__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest")
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
new file mode 100644
index 0000000..459a4d8
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -0,0 +1,1738 @@
+import operator
+import re
+
+import sqlalchemy as sa
+from .. import config
+from .. import engines
+from .. import eq_
+from .. import expect_warnings
+from .. import fixtures
+from .. import is_
+from ..provision import get_temp_table_name
+from ..provision import temp_table_keyword_args
+from ..schema import Column
+from ..schema import Table
+from ... import event
+from ... import ForeignKey
+from ... import func
+from ... import Identity
+from ... import inspect
+from ... import Integer
+from ... import MetaData
+from ... import String
+from ... import testing
+from ... import types as sql_types
+from ...schema import DDL
+from ...schema import Index
+from ...sql.elements import quoted_name
+from ...sql.schema import BLANK_SCHEMA
+from ...testing import is_false
+from ...testing import is_true
+
+
+metadata, users = None, None
+
+
+class HasTableTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ if testing.requires.schemas.enabled:
+ Table(
+ "test_table_s",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ schema=config.test_schema,
+ )
+
+ def test_has_table(self):
+ with config.db.begin() as conn:
+ is_true(config.db.dialect.has_table(conn, "test_table"))
+ is_false(config.db.dialect.has_table(conn, "test_table_s"))
+ is_false(config.db.dialect.has_table(conn, "nonexistent_table"))
+
+ @testing.requires.schemas
+ def test_has_table_schema(self):
+ with config.db.begin() as conn:
+ is_false(
+ config.db.dialect.has_table(
+ conn, "test_table", schema=config.test_schema
+ )
+ )
+ is_true(
+ config.db.dialect.has_table(
+ conn, "test_table_s", schema=config.test_schema
+ )
+ )
+ is_false(
+ config.db.dialect.has_table(
+ conn, "nonexistent_table", schema=config.test_schema
+ )
+ )
+
+
+class HasIndexTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ tt = Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ Index("my_idx", tt.c.data)
+
+ if testing.requires.schemas.enabled:
+ tt = Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ schema=config.test_schema,
+ )
+ Index("my_idx_s", tt.c.data)
+
+ def test_has_index(self):
+ with config.db.begin() as conn:
+ assert config.db.dialect.has_index(conn, "test_table", "my_idx")
+ assert not config.db.dialect.has_index(
+ conn, "test_table", "my_idx_s"
+ )
+ assert not config.db.dialect.has_index(
+ conn, "nonexistent_table", "my_idx"
+ )
+ assert not config.db.dialect.has_index(
+ conn, "test_table", "nonexistent_idx"
+ )
+
+ @testing.requires.schemas
+ def test_has_index_schema(self):
+ with config.db.begin() as conn:
+ assert config.db.dialect.has_index(
+ conn, "test_table", "my_idx_s", schema=config.test_schema
+ )
+ assert not config.db.dialect.has_index(
+ conn, "test_table", "my_idx", schema=config.test_schema
+ )
+ assert not config.db.dialect.has_index(
+ conn,
+ "nonexistent_table",
+ "my_idx_s",
+ schema=config.test_schema,
+ )
+ assert not config.db.dialect.has_index(
+ conn,
+ "test_table",
+ "nonexistent_idx_s",
+ schema=config.test_schema,
+ )
+
+
+class QuotedNameArgumentTest(fixtures.TablesTest):
+ run_create_tables = "once"
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "quote ' one",
+ metadata,
+ Column("id", Integer),
+ Column("name", String(50)),
+ Column("data", String(50)),
+ Column("related_id", Integer),
+ sa.PrimaryKeyConstraint("id", name="pk quote ' one"),
+ sa.Index("ix quote ' one", "name"),
+ sa.UniqueConstraint(
+ "data",
+ name="uq quote' one",
+ ),
+ sa.ForeignKeyConstraint(
+ ["id"], ["related.id"], name="fk quote ' one"
+ ),
+ sa.CheckConstraint("name != 'foo'", name="ck quote ' one"),
+ comment=r"""quote ' one comment""",
+ test_needs_fk=True,
+ )
+
+ if testing.requires.symbol_names_w_double_quote.enabled:
+ Table(
+ 'quote " two',
+ metadata,
+ Column("id", Integer),
+ Column("name", String(50)),
+ Column("data", String(50)),
+ Column("related_id", Integer),
+ sa.PrimaryKeyConstraint("id", name='pk quote " two'),
+ sa.Index('ix quote " two', "name"),
+ sa.UniqueConstraint(
+ "data",
+ name='uq quote" two',
+ ),
+ sa.ForeignKeyConstraint(
+ ["id"], ["related.id"], name='fk quote " two'
+ ),
+ sa.CheckConstraint("name != 'foo'", name='ck quote " two '),
+ comment=r"""quote " two comment""",
+ test_needs_fk=True,
+ )
+
+ Table(
+ "related",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("related", Integer),
+ test_needs_fk=True,
+ )
+
+ if testing.requires.view_column_reflection.enabled:
+
+ if testing.requires.symbol_names_w_double_quote.enabled:
+ names = [
+ "quote ' one",
+ 'quote " two',
+ ]
+ else:
+ names = [
+ "quote ' one",
+ ]
+ for name in names:
+ query = "CREATE VIEW %s AS SELECT * FROM %s" % (
+ config.db.dialect.identifier_preparer.quote(
+ "view %s" % name
+ ),
+ config.db.dialect.identifier_preparer.quote(name),
+ )
+
+ event.listen(metadata, "after_create", DDL(query))
+ event.listen(
+ metadata,
+ "before_drop",
+ DDL(
+ "DROP VIEW %s"
+ % config.db.dialect.identifier_preparer.quote(
+ "view %s" % name
+ )
+ ),
+ )
+
+ def quote_fixtures(fn):
+ return testing.combinations(
+ ("quote ' one",),
+ ('quote " two', testing.requires.symbol_names_w_double_quote),
+ )(fn)
+
+ @quote_fixtures
+ def test_get_table_options(self, name):
+ insp = inspect(config.db)
+
+ insp.get_table_options(name)
+
+ @quote_fixtures
+ @testing.requires.view_column_reflection
+ def test_get_view_definition(self, name):
+ insp = inspect(config.db)
+ assert insp.get_view_definition("view %s" % name)
+
+ @quote_fixtures
+ def test_get_columns(self, name):
+ insp = inspect(config.db)
+ assert insp.get_columns(name)
+
+ @quote_fixtures
+ def test_get_pk_constraint(self, name):
+ insp = inspect(config.db)
+ assert insp.get_pk_constraint(name)
+
+ @quote_fixtures
+ def test_get_foreign_keys(self, name):
+ insp = inspect(config.db)
+ assert insp.get_foreign_keys(name)
+
+ @quote_fixtures
+ def test_get_indexes(self, name):
+ insp = inspect(config.db)
+ assert insp.get_indexes(name)
+
+ @quote_fixtures
+ @testing.requires.unique_constraint_reflection
+ def test_get_unique_constraints(self, name):
+ insp = inspect(config.db)
+ assert insp.get_unique_constraints(name)
+
+ @quote_fixtures
+ @testing.requires.comment_reflection
+ def test_get_table_comment(self, name):
+ insp = inspect(config.db)
+ assert insp.get_table_comment(name)
+
+ @quote_fixtures
+ @testing.requires.check_constraint_reflection
+ def test_get_check_constraints(self, name):
+ insp = inspect(config.db)
+ assert insp.get_check_constraints(name)
+
+
+class ComponentReflectionTest(fixtures.TablesTest):
+ run_inserts = run_deletes = None
+
+ __backend__ = True
+
+ @classmethod
+ def setup_bind(cls):
+ if config.requirements.independent_connections.enabled:
+ from sqlalchemy import pool
+
+ return engines.testing_engine(
+ options=dict(poolclass=pool.StaticPool, scope="class"),
+ )
+ else:
+ return config.db
+
+ @classmethod
+ def define_tables(cls, metadata):
+ cls.define_reflected_tables(metadata, None)
+ if testing.requires.schemas.enabled:
+ cls.define_reflected_tables(metadata, testing.config.test_schema)
+
+ @classmethod
+ def define_reflected_tables(cls, metadata, schema):
+ if schema:
+ schema_prefix = schema + "."
+ else:
+ schema_prefix = ""
+
+ if testing.requires.self_referential_foreign_keys.enabled:
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ Column(
+ "parent_user_id",
+ sa.Integer,
+ sa.ForeignKey(
+ "%susers.user_id" % schema_prefix, name="user_id_fk"
+ ),
+ ),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ else:
+ users = Table(
+ "users",
+ metadata,
+ Column("user_id", sa.INT, primary_key=True),
+ Column("test1", sa.CHAR(5), nullable=False),
+ Column("test2", sa.Float(5), nullable=False),
+ schema=schema,
+ test_needs_fk=True,
+ )
+
+ Table(
+ "dingalings",
+ metadata,
+ Column("dingaling_id", sa.Integer, primary_key=True),
+ Column(
+ "address_id",
+ sa.Integer,
+ sa.ForeignKey("%semail_addresses.address_id" % schema_prefix),
+ ),
+ Column("data", sa.String(30)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "email_addresses",
+ metadata,
+ Column("address_id", sa.Integer),
+ Column(
+ "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id)
+ ),
+ Column("email_address", sa.String(20)),
+ sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "comment_test",
+ metadata,
+ Column("id", sa.Integer, primary_key=True, comment="id comment"),
+ Column("data", sa.String(20), comment="data % comment"),
+ Column(
+ "d2",
+ sa.String(20),
+ comment=r"""Comment types type speedily ' " \ '' Fun!""",
+ ),
+ schema=schema,
+ comment=r"""the test % ' " \ table comment""",
+ )
+
+ if testing.requires.cross_schema_fk_reflection.enabled:
+ if schema is None:
+ Table(
+ "local_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
+ Column(
+ "remote_id",
+ ForeignKey(
+ "%s.remote_table_2.id" % testing.config.test_schema
+ ),
+ ),
+ test_needs_fk=True,
+ schema=config.db.dialect.default_schema_name,
+ )
+ else:
+ Table(
+ "remote_table",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column(
+ "local_id",
+ ForeignKey(
+ "%s.local_table.id"
+ % config.db.dialect.default_schema_name
+ ),
+ ),
+ Column("data", sa.String(20)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+ Table(
+ "remote_table_2",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("data", sa.String(20)),
+ schema=schema,
+ test_needs_fk=True,
+ )
+
+ if testing.requires.index_reflection.enabled:
+ cls.define_index(metadata, users)
+
+ if not schema:
+ # test_needs_fk is at the moment to force MySQL InnoDB
+ noncol_idx_test_nopk = Table(
+ "noncol_idx_test_nopk",
+ metadata,
+ Column("q", sa.String(5)),
+ test_needs_fk=True,
+ )
+
+ noncol_idx_test_pk = Table(
+ "noncol_idx_test_pk",
+ metadata,
+ Column("id", sa.Integer, primary_key=True),
+ Column("q", sa.String(5)),
+ test_needs_fk=True,
+ )
+
+ if testing.requires.indexes_with_ascdesc.enabled:
+ Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc())
+ Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc())
+
+ if testing.requires.view_column_reflection.enabled:
+ cls.define_views(metadata, schema)
+ if not schema and testing.requires.temp_table_reflection.enabled:
+ cls.define_temp_tables(metadata)
+
+ @classmethod
+ def define_temp_tables(cls, metadata):
+ kw = temp_table_keyword_args(config, config.db)
+ table_name = get_temp_table_name(
+ config, config.db, "user_tmp_%s" % config.ident
+ )
+ user_tmp = Table(
+ table_name,
+ metadata,
+ Column("id", sa.INT, primary_key=True),
+ Column("name", sa.VARCHAR(50)),
+ Column("foo", sa.INT),
+ # disambiguate temp table unique constraint names. this is
+ # pretty arbitrary for a generic dialect however we are doing
+ # it to suit SQL Server which will produce name conflicts for
+ # unique constraints created against temp tables in different
+ # databases.
+ # https://www.arbinada.com/en/node/1645
+ sa.UniqueConstraint("name", name="user_tmp_uq_%s" % config.ident),
+ sa.Index("user_tmp_ix", "foo"),
+ **kw
+ )
+ if (
+ testing.requires.view_reflection.enabled
+ and testing.requires.temporary_views.enabled
+ ):
+ event.listen(
+ user_tmp,
+ "after_create",
+ DDL(
+ "create temporary view user_tmp_v as "
+ "select * from user_tmp_%s" % config.ident
+ ),
+ )
+ event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v"))
+
+ @classmethod
+ def define_index(cls, metadata, users):
+ Index("users_t_idx", users.c.test1, users.c.test2)
+ Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1)
+
+ @classmethod
+ def define_views(cls, metadata, schema):
+ for table_name in ("users", "email_addresses"):
+ fullname = table_name
+ if schema:
+ fullname = "%s.%s" % (schema, table_name)
+ view_name = fullname + "_v"
+ query = "CREATE VIEW %s AS SELECT * FROM %s" % (
+ view_name,
+ fullname,
+ )
+
+ event.listen(metadata, "after_create", DDL(query))
+ event.listen(
+ metadata, "before_drop", DDL("DROP VIEW %s" % view_name)
+ )
+
+ @testing.requires.schema_reflection
+ def test_get_schema_names(self):
+ insp = inspect(self.bind)
+
+ self.assert_(testing.config.test_schema in insp.get_schema_names())
+
+ @testing.requires.schema_reflection
+ def test_get_schema_names_w_translate_map(self, connection):
+ """test #7300"""
+
+ connection = connection.execution_options(
+ schema_translate_map={
+ "foo": "bar",
+ BLANK_SCHEMA: testing.config.test_schema,
+ }
+ )
+ insp = inspect(connection)
+
+ self.assert_(testing.config.test_schema in insp.get_schema_names())
+
+ @testing.requires.schema_reflection
+ def test_dialect_initialize(self):
+ engine = engines.testing_engine()
+ inspect(engine)
+ assert hasattr(engine.dialect, "default_schema_name")
+
+ @testing.requires.schema_reflection
+ def test_get_default_schema_name(self):
+ insp = inspect(self.bind)
+ eq_(insp.default_schema_name, self.bind.dialect.default_schema_name)
+
+ @testing.requires.foreign_key_constraint_reflection
+ @testing.combinations(
+ (None, True, False, False),
+ (None, True, False, True, testing.requires.schemas),
+ ("foreign_key", True, False, False),
+ (None, False, True, False),
+ (None, False, True, True, testing.requires.schemas),
+ (None, True, True, False),
+ (None, True, True, True, testing.requires.schemas),
+ argnames="order_by,include_plain,include_views,use_schema",
+ )
+ def test_get_table_names(
+ self, connection, order_by, include_plain, include_views, use_schema
+ ):
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ _ignore_tables = [
+ "comment_test",
+ "noncol_idx_test_pk",
+ "noncol_idx_test_nopk",
+ "local_table",
+ "remote_table",
+ "remote_table_2",
+ ]
+
+ insp = inspect(connection)
+
+ if include_views:
+ table_names = insp.get_view_names(schema)
+ table_names.sort()
+ answer = ["email_addresses_v", "users_v"]
+ eq_(sorted(table_names), answer)
+
+ if include_plain:
+ if order_by:
+ tables = [
+ rec[0]
+ for rec in insp.get_sorted_table_and_fkc_names(schema)
+ if rec[0]
+ ]
+ else:
+ tables = insp.get_table_names(schema)
+ table_names = [t for t in tables if t not in _ignore_tables]
+
+ if order_by == "foreign_key":
+ answer = ["users", "email_addresses", "dingalings"]
+ eq_(table_names, answer)
+ else:
+ answer = ["dingalings", "email_addresses", "users"]
+ eq_(sorted(table_names), answer)
+
+ @testing.requires.temp_table_names
+ def test_get_temp_table_names(self):
+ insp = inspect(self.bind)
+ temp_table_names = insp.get_temp_table_names()
+ eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident])
+
+ @testing.requires.view_reflection
+ @testing.requires.temp_table_names
+ @testing.requires.temporary_views
+ def test_get_temp_view_names(self):
+ insp = inspect(self.bind)
+ temp_table_names = insp.get_temp_view_names()
+ eq_(sorted(temp_table_names), ["user_tmp_v"])
+
+ @testing.requires.comment_reflection
+ def test_get_comments(self):
+ self._test_get_comments()
+
+ @testing.requires.comment_reflection
+ @testing.requires.schemas
+ def test_get_comments_with_schema(self):
+ self._test_get_comments(testing.config.test_schema)
+
+ def _test_get_comments(self, schema=None):
+ insp = inspect(self.bind)
+
+ eq_(
+ insp.get_table_comment("comment_test", schema=schema),
+ {"text": r"""the test % ' " \ table comment"""},
+ )
+
+ eq_(insp.get_table_comment("users", schema=schema), {"text": None})
+
+ eq_(
+ [
+ {"name": rec["name"], "comment": rec["comment"]}
+ for rec in insp.get_columns("comment_test", schema=schema)
+ ],
+ [
+ {"comment": "id comment", "name": "id"},
+ {"comment": "data % comment", "name": "data"},
+ {
+ "comment": (
+ r"""Comment types type speedily ' " \ '' Fun!"""
+ ),
+ "name": "d2",
+ },
+ ],
+ )
+
+ @testing.combinations(
+ (False, False),
+ (False, True, testing.requires.schemas),
+ (True, False, testing.requires.view_reflection),
+ (
+ True,
+ True,
+ testing.requires.schemas + testing.requires.view_reflection,
+ ),
+ argnames="use_views,use_schema",
+ )
+ def test_get_columns(self, connection, use_views, use_schema):
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ users, addresses = (self.tables.users, self.tables.email_addresses)
+ if use_views:
+ table_names = ["users_v", "email_addresses_v"]
+ else:
+ table_names = ["users", "email_addresses"]
+
+ insp = inspect(connection)
+ for table_name, table in zip(table_names, (users, addresses)):
+ schema_name = schema
+ cols = insp.get_columns(table_name, schema=schema_name)
+ self.assert_(len(cols) > 0, len(cols))
+
+ # should be in order
+
+ for i, col in enumerate(table.columns):
+ eq_(col.name, cols[i]["name"])
+ ctype = cols[i]["type"].__class__
+ ctype_def = col.type
+ if isinstance(ctype_def, sa.types.TypeEngine):
+ ctype_def = ctype_def.__class__
+
+ # Oracle returns Date for DateTime.
+
+ if testing.against("oracle") and ctype_def in (
+ sql_types.Date,
+ sql_types.DateTime,
+ ):
+ ctype_def = sql_types.Date
+
+ # assert that the desired type and return type share
+ # a base within one of the generic types.
+
+ self.assert_(
+ len(
+ set(ctype.__mro__)
+ .intersection(ctype_def.__mro__)
+ .intersection(
+ [
+ sql_types.Integer,
+ sql_types.Numeric,
+ sql_types.DateTime,
+ sql_types.Date,
+ sql_types.Time,
+ sql_types.String,
+ sql_types._Binary,
+ ]
+ )
+ )
+ > 0,
+ "%s(%s), %s(%s)"
+ % (col.name, col.type, cols[i]["name"], ctype),
+ )
+
+ if not col.primary_key:
+ assert cols[i]["default"] is None
+
+ @testing.requires.temp_table_reflection
+ def test_get_temp_table_columns(self):
+ table_name = get_temp_table_name(
+ config, self.bind, "user_tmp_%s" % config.ident
+ )
+ user_tmp = self.tables[table_name]
+ insp = inspect(self.bind)
+ cols = insp.get_columns(table_name)
+ self.assert_(len(cols) > 0, len(cols))
+
+ for i, col in enumerate(user_tmp.columns):
+ eq_(col.name, cols[i]["name"])
+
+ @testing.requires.temp_table_reflection
+ @testing.requires.view_column_reflection
+ @testing.requires.temporary_views
+ def test_get_temp_view_columns(self):
+ insp = inspect(self.bind)
+ cols = insp.get_columns("user_tmp_v")
+ eq_([col["name"] for col in cols], ["id", "name", "foo"])
+
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ @testing.requires.primary_key_constraint_reflection
+ def test_get_pk_constraint(self, connection, use_schema):
+ if use_schema:
+ schema = testing.config.test_schema
+ else:
+ schema = None
+
+ users, addresses = self.tables.users, self.tables.email_addresses
+ insp = inspect(connection)
+
+ users_cons = insp.get_pk_constraint(users.name, schema=schema)
+ users_pkeys = users_cons["constrained_columns"]
+ eq_(users_pkeys, ["user_id"])
+
+ addr_cons = insp.get_pk_constraint(addresses.name, schema=schema)
+ addr_pkeys = addr_cons["constrained_columns"]
+ eq_(addr_pkeys, ["address_id"])
+
+ with testing.requires.reflects_pk_names.fail_if():
+ eq_(addr_cons["name"], "email_ad_pk")
+
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ @testing.requires.foreign_key_constraint_reflection
+ def test_get_foreign_keys(self, connection, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ users, addresses = (self.tables.users, self.tables.email_addresses)
+ insp = inspect(connection)
+ expected_schema = schema
+ # users
+
+ if testing.requires.self_referential_foreign_keys.enabled:
+ users_fkeys = insp.get_foreign_keys(users.name, schema=schema)
+ fkey1 = users_fkeys[0]
+
+ with testing.requires.named_constraints.fail_if():
+ eq_(fkey1["name"], "user_id_fk")
+
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
+ if testing.requires.self_referential_foreign_keys.enabled:
+ eq_(fkey1["constrained_columns"], ["parent_user_id"])
+
+ # addresses
+ addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema)
+ fkey1 = addr_fkeys[0]
+
+ with testing.requires.implicitly_named_constraints.fail_if():
+ self.assert_(fkey1["name"] is not None)
+
+ eq_(fkey1["referred_schema"], expected_schema)
+ eq_(fkey1["referred_table"], users.name)
+ eq_(fkey1["referred_columns"], ["user_id"])
+ eq_(fkey1["constrained_columns"], ["remote_user_id"])
+
+ @testing.requires.cross_schema_fk_reflection
+ @testing.requires.schemas
+ def test_get_inter_schema_foreign_keys(self):
+ local_table, remote_table, remote_table_2 = self.tables(
+ "%s.local_table" % self.bind.dialect.default_schema_name,
+ "%s.remote_table" % testing.config.test_schema,
+ "%s.remote_table_2" % testing.config.test_schema,
+ )
+
+ insp = inspect(self.bind)
+
+ local_fkeys = insp.get_foreign_keys(local_table.name)
+ eq_(len(local_fkeys), 1)
+
+ fkey1 = local_fkeys[0]
+ eq_(fkey1["referred_schema"], testing.config.test_schema)
+ eq_(fkey1["referred_table"], remote_table_2.name)
+ eq_(fkey1["referred_columns"], ["id"])
+ eq_(fkey1["constrained_columns"], ["remote_id"])
+
+ remote_fkeys = insp.get_foreign_keys(
+ remote_table.name, schema=testing.config.test_schema
+ )
+ eq_(len(remote_fkeys), 1)
+
+ fkey2 = remote_fkeys[0]
+
+ assert fkey2["referred_schema"] in (
+ None,
+ self.bind.dialect.default_schema_name,
+ )
+ eq_(fkey2["referred_table"], local_table.name)
+ eq_(fkey2["referred_columns"], ["id"])
+ eq_(fkey2["constrained_columns"], ["local_id"])
+
+ def _assert_insp_indexes(self, indexes, expected_indexes):
+ index_names = [d["name"] for d in indexes]
+ for e_index in expected_indexes:
+ assert e_index["name"] in index_names
+ index = indexes[index_names.index(e_index["name"])]
+ for key in e_index:
+ eq_(e_index[key], index[key])
+
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ def test_get_indexes(self, connection, use_schema):
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ # The database may decide to create indexes for foreign keys, etc.
+ # so there may be more indexes than expected.
+ insp = inspect(self.bind)
+ indexes = insp.get_indexes("users", schema=schema)
+ expected_indexes = [
+ {
+ "unique": False,
+ "column_names": ["test1", "test2"],
+ "name": "users_t_idx",
+ },
+ {
+ "unique": False,
+ "column_names": ["user_id", "test2", "test1"],
+ "name": "users_all_idx",
+ },
+ ]
+ self._assert_insp_indexes(indexes, expected_indexes)
+
+ @testing.combinations(
+ ("noncol_idx_test_nopk", "noncol_idx_nopk"),
+ ("noncol_idx_test_pk", "noncol_idx_pk"),
+ argnames="tname,ixname",
+ )
+ @testing.requires.index_reflection
+ @testing.requires.indexes_with_ascdesc
+ def test_get_noncol_index(self, connection, tname, ixname):
+ insp = inspect(connection)
+ indexes = insp.get_indexes(tname)
+
+ # reflecting an index that has "x DESC" in it as the column.
+ # the DB may or may not give us "x", but make sure we get the index
+ # back, it has a name, it's connected to the table.
+ expected_indexes = [{"unique": False, "name": ixname}]
+ self._assert_insp_indexes(indexes, expected_indexes)
+
+ t = Table(tname, MetaData(), autoload_with=connection)
+ eq_(len(t.indexes), 1)
+ is_(list(t.indexes)[0].table, t)
+ eq_(list(t.indexes)[0].name, ixname)
+
+ @testing.requires.temp_table_reflection
+ @testing.requires.unique_constraint_reflection
+ def test_get_temp_table_unique_constraints(self):
+ insp = inspect(self.bind)
+ reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident)
+ for refl in reflected:
+ # Different dialects handle duplicate index and constraints
+ # differently, so ignore this flag
+ refl.pop("duplicates_index", None)
+ eq_(
+ reflected,
+ [
+ {
+ "column_names": ["name"],
+ "name": "user_tmp_uq_%s" % config.ident,
+ }
+ ],
+ )
+
+ @testing.requires.temp_table_reflect_indexes
+ def test_get_temp_table_indexes(self):
+ insp = inspect(self.bind)
+ table_name = get_temp_table_name(
+ config, config.db, "user_tmp_%s" % config.ident
+ )
+ indexes = insp.get_indexes(table_name)
+ for ind in indexes:
+ ind.pop("dialect_options", None)
+ expected = [
+ {"unique": False, "column_names": ["foo"], "name": "user_tmp_ix"}
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ eq_(
+ [idx for idx in indexes if idx["name"] == "user_tmp_ix"],
+ expected,
+ )
+
+ @testing.combinations(
+ (True, testing.requires.schemas), (False,), argnames="use_schema"
+ )
+ @testing.requires.unique_constraint_reflection
+ def test_get_unique_constraints(self, metadata, connection, use_schema):
+ # SQLite dialect needs to parse the names of the constraints
+ # separately from what it gets from PRAGMA index_list(), and
+ # then matches them up. so same set of column_names in two
+ # constraints will confuse it. Perhaps we should no longer
+ # bother with index_list() here since we have the whole
+ # CREATE TABLE?
+
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+ uniques = sorted(
+ [
+ {"name": "unique_a", "column_names": ["a"]},
+ {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]},
+ {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]},
+ {"name": "unique_asc_key", "column_names": ["asc", "key"]},
+ {"name": "i.have.dots", "column_names": ["b"]},
+ {"name": "i have spaces", "column_names": ["c"]},
+ ],
+ key=operator.itemgetter("name"),
+ )
+ table = Table(
+ "testtbl",
+ metadata,
+ Column("a", sa.String(20)),
+ Column("b", sa.String(30)),
+ Column("c", sa.Integer),
+ # reserved identifiers
+ Column("asc", sa.String(30)),
+ Column("key", sa.String(30)),
+ schema=schema,
+ )
+ for uc in uniques:
+ table.append_constraint(
+ sa.UniqueConstraint(*uc["column_names"], name=uc["name"])
+ )
+ table.create(connection)
+
+ inspector = inspect(connection)
+ reflected = sorted(
+ inspector.get_unique_constraints("testtbl", schema=schema),
+ key=operator.itemgetter("name"),
+ )
+
+ names_that_duplicate_index = set()
+
+ for orig, refl in zip(uniques, reflected):
+ # Different dialects handle duplicate index and constraints
+ # differently, so ignore this flag
+ dupe = refl.pop("duplicates_index", None)
+ if dupe:
+ names_that_duplicate_index.add(dupe)
+ eq_(orig, refl)
+
+ reflected_metadata = MetaData()
+ reflected = Table(
+ "testtbl",
+ reflected_metadata,
+ autoload_with=connection,
+ schema=schema,
+ )
+
+ # test "deduplicates for index" logic. MySQL and Oracle
+ # "unique constraints" are actually unique indexes (with possible
+ # exception of a unique that is a dupe of another one in the case
+ # of Oracle). make sure # they aren't duplicated.
+ idx_names = set([idx.name for idx in reflected.indexes])
+ uq_names = set(
+ [
+ uq.name
+ for uq in reflected.constraints
+ if isinstance(uq, sa.UniqueConstraint)
+ ]
+ ).difference(["unique_c_a_b"])
+
+ assert not idx_names.intersection(uq_names)
+ if names_that_duplicate_index:
+ eq_(names_that_duplicate_index, idx_names)
+ eq_(uq_names, set())
+
+ @testing.requires.view_reflection
+ @testing.combinations(
+ (False,), (True, testing.requires.schemas), argnames="use_schema"
+ )
+ def test_get_view_definition(self, connection, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+ view_name1 = "users_v"
+ view_name2 = "email_addresses_v"
+ insp = inspect(connection)
+ v1 = insp.get_view_definition(view_name1, schema=schema)
+ self.assert_(v1)
+ v2 = insp.get_view_definition(view_name2, schema=schema)
+ self.assert_(v2)
+
+ # why is this here if it's PG specific ?
+ @testing.combinations(
+ ("users", False),
+ ("users", True, testing.requires.schemas),
+ argnames="table_name,use_schema",
+ )
+ @testing.only_on("postgresql", "PG specific feature")
+ def test_get_table_oid(self, connection, table_name, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+ insp = inspect(connection)
+ oid = insp.get_table_oid(table_name, schema)
+ self.assert_(isinstance(oid, int))
+
+ @testing.requires.table_reflection
+ def test_autoincrement_col(self):
+ """test that 'autoincrement' is reflected according to sqla's policy.
+
+ Don't mark this test as unsupported for any backend !
+
+ (technically it fails with MySQL InnoDB since "id" comes before "id2")
+
+ A backend is better off not returning "autoincrement" at all,
+ instead of potentially returning "False" for an auto-incrementing
+ primary key column.
+
+ """
+
+ insp = inspect(self.bind)
+
+ for tname, cname in [
+ ("users", "user_id"),
+ ("email_addresses", "address_id"),
+ ("dingalings", "dingaling_id"),
+ ]:
+ cols = insp.get_columns(tname)
+ id_ = {c["name"]: c for c in cols}[cname]
+ assert id_.get("autoincrement", True)
+
+
+class TableNoColumnsTest(fixtures.TestBase):
+ __requires__ = ("reflect_tables_no_columns",)
+ __backend__ = True
+
+ @testing.fixture
+ def table_no_columns(self, connection, metadata):
+ Table("empty", metadata)
+ metadata.create_all(connection)
+
+ @testing.fixture
+ def view_no_columns(self, connection, metadata):
+ Table("empty", metadata)
+ metadata.create_all(connection)
+
+ Table("empty", metadata)
+ event.listen(
+ metadata,
+ "after_create",
+ DDL("CREATE VIEW empty_v AS SELECT * FROM empty"),
+ )
+
+ # for transactional DDL the transaction is rolled back before this
+ # drop statement is invoked
+ event.listen(
+ metadata, "before_drop", DDL("DROP VIEW IF EXISTS empty_v")
+ )
+ metadata.create_all(connection)
+
+ @testing.requires.reflect_tables_no_columns
+ def test_reflect_table_no_columns(self, connection, table_no_columns):
+ t2 = Table("empty", MetaData(), autoload_with=connection)
+ eq_(list(t2.c), [])
+
+ @testing.requires.reflect_tables_no_columns
+ def test_get_columns_table_no_columns(self, connection, table_no_columns):
+ eq_(inspect(connection).get_columns("empty"), [])
+
+ @testing.requires.reflect_tables_no_columns
+ def test_reflect_incl_table_no_columns(self, connection, table_no_columns):
+ m = MetaData()
+ m.reflect(connection)
+ assert set(m.tables).intersection(["empty"])
+
+ @testing.requires.views
+ @testing.requires.reflect_tables_no_columns
+ def test_reflect_view_no_columns(self, connection, view_no_columns):
+ t2 = Table("empty_v", MetaData(), autoload_with=connection)
+ eq_(list(t2.c), [])
+
+ @testing.requires.views
+ @testing.requires.reflect_tables_no_columns
+ def test_get_columns_view_no_columns(self, connection, view_no_columns):
+ eq_(inspect(connection).get_columns("empty_v"), [])
+
+
+class ComponentReflectionTestExtra(fixtures.TestBase):
+
+ __backend__ = True
+
+ @testing.combinations(
+ (True, testing.requires.schemas), (False,), argnames="use_schema"
+ )
+ @testing.requires.check_constraint_reflection
+ def test_get_check_constraints(self, metadata, connection, use_schema):
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+
+ Table(
+ "sa_cc",
+ metadata,
+ Column("a", Integer()),
+ sa.CheckConstraint("a > 1 AND a < 5", name="cc1"),
+ sa.CheckConstraint(
+ "a = 1 OR (a > 2 AND a < 5)", name="UsesCasing"
+ ),
+ schema=schema,
+ )
+
+ metadata.create_all(connection)
+
+ inspector = inspect(connection)
+ reflected = sorted(
+ inspector.get_check_constraints("sa_cc", schema=schema),
+ key=operator.itemgetter("name"),
+ )
+
+ # trying to minimize effect of quoting, parenthesis, etc.
+ # may need to add more to this as new dialects get CHECK
+ # constraint reflection support
+ def normalize(sqltext):
+ return " ".join(
+ re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I)
+ )
+
+ reflected = [
+ {"name": item["name"], "sqltext": normalize(item["sqltext"])}
+ for item in reflected
+ ]
+ eq_(
+ reflected,
+ [
+ {"name": "UsesCasing", "sqltext": "a = 1 or a > 2 and a < 5"},
+ {"name": "cc1", "sqltext": "a > 1 and a < 5"},
+ ],
+ )
+
+ @testing.requires.indexes_with_expressions
+ def test_reflect_expression_based_indexes(self, metadata, connection):
+ t = Table(
+ "t",
+ metadata,
+ Column("x", String(30)),
+ Column("y", String(30)),
+ )
+
+ Index("t_idx", func.lower(t.c.x), func.lower(t.c.y))
+
+ Index("t_idx_2", t.c.x)
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+ expected = [
+ {"name": "t_idx_2", "column_names": ["x"], "unique": False}
+ ]
+ if testing.requires.index_reflects_included_columns.enabled:
+ expected[0]["include_columns"] = []
+ expected[0]["dialect_options"] = {
+ "%s_include" % connection.engine.name: []
+ }
+
+ with expect_warnings(
+ "Skipped unsupported reflection of expression-based index t_idx"
+ ):
+ eq_(
+ insp.get_indexes("t"),
+ expected,
+ )
+
+ @testing.requires.index_reflects_included_columns
+ def test_reflect_covering_index(self, metadata, connection):
+ t = Table(
+ "t",
+ metadata,
+ Column("x", String(30)),
+ Column("y", String(30)),
+ )
+ idx = Index("t_idx", t.c.x)
+ idx.dialect_options[connection.engine.name]["include"] = ["y"]
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+ eq_(
+ insp.get_indexes("t"),
+ [
+ {
+ "name": "t_idx",
+ "column_names": ["x"],
+ "include_columns": ["y"],
+ "unique": False,
+ "dialect_options": {
+ "%s_include" % connection.engine.name: ["y"]
+ },
+ }
+ ],
+ )
+
+ t2 = Table("t", MetaData(), autoload_with=connection)
+ eq_(
+ list(t2.indexes)[0].dialect_options[connection.engine.name][
+ "include"
+ ],
+ ["y"],
+ )
+
+ def _type_round_trip(self, connection, metadata, *types):
+ t = Table(
+ "t",
+ metadata,
+ *[Column("t%d" % i, type_) for i, type_ in enumerate(types)]
+ )
+ t.create(connection)
+
+ return [c["type"] for c in inspect(connection).get_columns("t")]
+
+ @testing.requires.table_reflection
+ def test_numeric_reflection(self, connection, metadata):
+ for typ in self._type_round_trip(
+ connection, metadata, sql_types.Numeric(18, 5)
+ ):
+ assert isinstance(typ, sql_types.Numeric)
+ eq_(typ.precision, 18)
+ eq_(typ.scale, 5)
+
+ @testing.requires.table_reflection
+ def test_varchar_reflection(self, connection, metadata):
+ typ = self._type_round_trip(
+ connection, metadata, sql_types.String(52)
+ )[0]
+ assert isinstance(typ, sql_types.String)
+ eq_(typ.length, 52)
+
+ @testing.requires.table_reflection
+ def test_nullable_reflection(self, connection, metadata):
+ t = Table(
+ "t",
+ metadata,
+ Column("a", Integer, nullable=True),
+ Column("b", Integer, nullable=False),
+ )
+ t.create(connection)
+ eq_(
+ dict(
+ (col["name"], col["nullable"])
+ for col in inspect(connection).get_columns("t")
+ ),
+ {"a": True, "b": False},
+ )
+
+ @testing.combinations(
+ (
+ None,
+ "CASCADE",
+ None,
+ testing.requires.foreign_key_constraint_option_reflection_ondelete,
+ ),
+ (
+ None,
+ None,
+ "SET NULL",
+ testing.requires.foreign_key_constraint_option_reflection_onupdate,
+ ),
+ (
+ {},
+ None,
+ "NO ACTION",
+ testing.requires.foreign_key_constraint_option_reflection_onupdate,
+ ),
+ (
+ {},
+ "NO ACTION",
+ None,
+ testing.requires.fk_constraint_option_reflection_ondelete_noaction,
+ ),
+ (
+ None,
+ None,
+ "RESTRICT",
+ testing.requires.fk_constraint_option_reflection_onupdate_restrict,
+ ),
+ (
+ None,
+ "RESTRICT",
+ None,
+ testing.requires.fk_constraint_option_reflection_ondelete_restrict,
+ ),
+ argnames="expected,ondelete,onupdate",
+ )
+ def test_get_foreign_key_options(
+ self, connection, metadata, expected, ondelete, onupdate
+ ):
+ options = {}
+ if ondelete:
+ options["ondelete"] = ondelete
+ if onupdate:
+ options["onupdate"] = onupdate
+
+ if expected is None:
+ expected = options
+
+ Table(
+ "x",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")),
+ Column("test", String(10)),
+ test_needs_fk=True,
+ )
+
+ Table(
+ "user",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(50), nullable=False),
+ Column("tid", Integer),
+ sa.ForeignKeyConstraint(
+ ["tid"], ["table.id"], name="myfk", **options
+ ),
+ test_needs_fk=True,
+ )
+
+ metadata.create_all(connection)
+
+ insp = inspect(connection)
+
+ # test 'options' is always present for a backend
+ # that can reflect these, since alembic looks for this
+ opts = insp.get_foreign_keys("table")[0]["options"]
+
+ eq_(dict((k, opts[k]) for k in opts if opts[k]), {})
+
+ opts = insp.get_foreign_keys("user")[0]["options"]
+ eq_(opts, expected)
+ # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected)
+
+
+class NormalizedNameTest(fixtures.TablesTest):
+ __requires__ = ("denormalized_names",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ quoted_name("t1", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
+ )
+ Table(
+ quoted_name("t2", quote=True),
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("t1id", ForeignKey("t1.id")),
+ )
+
+ def test_reflect_lowercase_forced_tables(self):
+
+ m2 = MetaData()
+ t2_ref = Table(
+ quoted_name("t2", quote=True), m2, autoload_with=config.db
+ )
+ t1_ref = m2.tables["t1"]
+ assert t2_ref.c.t1id.references(t1_ref.c.id)
+
+ m3 = MetaData()
+ m3.reflect(
+ config.db, only=lambda name, m: name.lower() in ("t1", "t2")
+ )
+ assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id)
+
+ def test_get_table_names(self):
+ tablenames = [
+ t
+ for t in inspect(config.db).get_table_names()
+ if t.lower() in ("t1", "t2")
+ ]
+
+ eq_(tablenames[0].upper(), tablenames[0].lower())
+ eq_(tablenames[1].upper(), tablenames[1].lower())
+
+
+class ComputedReflectionTest(fixtures.ComputedReflectionFixtureTest):
+ def test_computed_col_default_not_set(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("computed_default_table")
+ col_data = {c["name"]: c for c in cols}
+ is_true("42" in col_data["with_default"]["default"])
+ is_(col_data["normal"]["default"], None)
+ is_(col_data["computed_col"]["default"], None)
+
+ def test_get_column_returns_computed(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("computed_default_table")
+ data = {c["name"]: c for c in cols}
+ for key in ("id", "normal", "with_default"):
+ is_true("computed" not in data[key])
+ compData = data["computed_col"]
+ is_true("computed" in compData)
+ is_true("sqltext" in compData["computed"])
+ eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42")
+ eq_(
+ "persisted" in compData["computed"],
+ testing.requires.computed_columns_reflect_persisted.enabled,
+ )
+ if testing.requires.computed_columns_reflect_persisted.enabled:
+ eq_(
+ compData["computed"]["persisted"],
+ testing.requires.computed_columns_default_persisted.enabled,
+ )
+
+ def check_column(self, data, column, sqltext, persisted):
+ is_true("computed" in data[column])
+ compData = data[column]["computed"]
+ eq_(self.normalize(compData["sqltext"]), sqltext)
+ if testing.requires.computed_columns_reflect_persisted.enabled:
+ is_true("persisted" in compData)
+ is_(compData["persisted"], persisted)
+
+ def test_get_column_returns_persisted(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("computed_column_table")
+ data = {c["name"]: c for c in cols}
+
+ self.check_column(
+ data,
+ "computed_no_flag",
+ "normal+42",
+ testing.requires.computed_columns_default_persisted.enabled,
+ )
+ if testing.requires.computed_columns_virtual.enabled:
+ self.check_column(
+ data,
+ "computed_virtual",
+ "normal+2",
+ False,
+ )
+ if testing.requires.computed_columns_stored.enabled:
+ self.check_column(
+ data,
+ "computed_stored",
+ "normal-42",
+ True,
+ )
+
+ @testing.requires.schemas
+ def test_get_column_returns_persisted_with_schema(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns(
+ "computed_column_table", schema=config.test_schema
+ )
+ data = {c["name"]: c for c in cols}
+
+ self.check_column(
+ data,
+ "computed_no_flag",
+ "normal/42",
+ testing.requires.computed_columns_default_persisted.enabled,
+ )
+ if testing.requires.computed_columns_virtual.enabled:
+ self.check_column(
+ data,
+ "computed_virtual",
+ "normal/2",
+ False,
+ )
+ if testing.requires.computed_columns_stored.enabled:
+ self.check_column(
+ data,
+ "computed_stored",
+ "normal*42",
+ True,
+ )
+
+
+class IdentityReflectionTest(fixtures.TablesTest):
+ run_inserts = run_deletes = None
+
+ __backend__ = True
+ __requires__ = ("identity_columns", "table_reflection")
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "t1",
+ metadata,
+ Column("normal", Integer),
+ Column("id1", Integer, Identity()),
+ )
+ Table(
+ "t2",
+ metadata,
+ Column(
+ "id2",
+ Integer,
+ Identity(
+ always=True,
+ start=2,
+ increment=3,
+ minvalue=-2,
+ maxvalue=42,
+ cycle=True,
+ cache=4,
+ ),
+ ),
+ )
+ if testing.requires.schemas.enabled:
+ Table(
+ "t1",
+ metadata,
+ Column("normal", Integer),
+ Column("id1", Integer, Identity(always=True, start=20)),
+ schema=config.test_schema,
+ )
+
+ def check(self, value, exp, approx):
+ if testing.requires.identity_columns_standard.enabled:
+ common_keys = (
+ "always",
+ "start",
+ "increment",
+ "minvalue",
+ "maxvalue",
+ "cycle",
+ "cache",
+ )
+ for k in list(value):
+ if k not in common_keys:
+ value.pop(k)
+ if approx:
+ eq_(len(value), len(exp))
+ for k in value:
+ if k == "minvalue":
+ is_true(value[k] <= exp[k])
+ elif k in {"maxvalue", "cache"}:
+ is_true(value[k] >= exp[k])
+ else:
+ eq_(value[k], exp[k], k)
+ else:
+ eq_(value, exp)
+ else:
+ eq_(value["start"], exp["start"])
+ eq_(value["increment"], exp["increment"])
+
+ def test_reflect_identity(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("t1") + insp.get_columns("t2")
+ for col in cols:
+ if col["name"] == "normal":
+ is_false("identity" in col)
+ elif col["name"] == "id1":
+ is_true(col["autoincrement"] in (True, "auto"))
+ eq_(col["default"], None)
+ is_true("identity" in col)
+ self.check(
+ col["identity"],
+ dict(
+ always=False,
+ start=1,
+ increment=1,
+ minvalue=1,
+ maxvalue=2147483647,
+ cycle=False,
+ cache=1,
+ ),
+ approx=True,
+ )
+ elif col["name"] == "id2":
+ is_true(col["autoincrement"] in (True, "auto"))
+ eq_(col["default"], None)
+ is_true("identity" in col)
+ self.check(
+ col["identity"],
+ dict(
+ always=True,
+ start=2,
+ increment=3,
+ minvalue=-2,
+ maxvalue=42,
+ cycle=True,
+ cache=4,
+ ),
+ approx=False,
+ )
+
+ @testing.requires.schemas
+ def test_reflect_identity_schema(self):
+ insp = inspect(config.db)
+
+ cols = insp.get_columns("t1", schema=config.test_schema)
+ for col in cols:
+ if col["name"] == "normal":
+ is_false("identity" in col)
+ elif col["name"] == "id1":
+ is_true(col["autoincrement"] in (True, "auto"))
+ eq_(col["default"], None)
+ is_true("identity" in col)
+ self.check(
+ col["identity"],
+ dict(
+ always=True,
+ start=20,
+ increment=1,
+ minvalue=1,
+ maxvalue=2147483647,
+ cycle=False,
+ cache=1,
+ ),
+ approx=True,
+ )
+
+
+class CompositeKeyReflectionTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ tb1 = Table(
+ "tb1",
+ metadata,
+ Column("id", Integer),
+ Column("attr", Integer),
+ Column("name", sql_types.VARCHAR(20)),
+ sa.PrimaryKeyConstraint("name", "id", "attr", name="pk_tb1"),
+ schema=None,
+ test_needs_fk=True,
+ )
+ Table(
+ "tb2",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("pid", Integer),
+ Column("pattr", Integer),
+ Column("pname", sql_types.VARCHAR(20)),
+ sa.ForeignKeyConstraint(
+ ["pname", "pid", "pattr"],
+ [tb1.c.name, tb1.c.id, tb1.c.attr],
+ name="fk_tb1_name_id_attr",
+ ),
+ schema=None,
+ test_needs_fk=True,
+ )
+
+ @testing.requires.primary_key_constraint_reflection
+ def test_pk_column_order(self):
+ # test for issue #5661
+ insp = inspect(self.bind)
+ primary_key = insp.get_pk_constraint(self.tables.tb1.name)
+ eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"])
+
+ @testing.requires.foreign_key_constraint_reflection
+ def test_fk_column_order(self):
+ # test for issue #5661
+ insp = inspect(self.bind)
+ foreign_keys = insp.get_foreign_keys(self.tables.tb2.name)
+ eq_(len(foreign_keys), 1)
+ fkey1 = foreign_keys[0]
+ eq_(fkey1.get("referred_columns"), ["name", "id", "attr"])
+ eq_(fkey1.get("constrained_columns"), ["pname", "pid", "pattr"])
+
+
+__all__ = (
+ "ComponentReflectionTest",
+ "ComponentReflectionTestExtra",
+ "TableNoColumnsTest",
+ "QuotedNameArgumentTest",
+ "HasTableTest",
+ "HasIndexTest",
+ "NormalizedNameTest",
+ "ComputedReflectionTest",
+ "IdentityReflectionTest",
+ "CompositeKeyReflectionTest",
+)
diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py
new file mode 100644
index 0000000..c41a550
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_results.py
@@ -0,0 +1,426 @@
+import datetime
+
+from .. import engines
+from .. import fixtures
+from ..assertions import eq_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import DateTime
+from ... import func
+from ... import Integer
+from ... import select
+from ... import sql
+from ... import String
+from ... import testing
+from ... import text
+from ... import util
+
+
+class RowFetchTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ Table(
+ "has_dates",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("today", DateTime),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.plain_pk.insert(),
+ [
+ {"id": 1, "data": "d1"},
+ {"id": 2, "data": "d2"},
+ {"id": 3, "data": "d3"},
+ ],
+ )
+
+ connection.execute(
+ cls.tables.has_dates.insert(),
+ [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}],
+ )
+
+ def test_via_attr(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row.id, 1)
+ eq_(row.data, "d1")
+
+ def test_via_string(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row._mapping["id"], 1)
+ eq_(row._mapping["data"], "d1")
+
+ def test_via_int(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row[0], 1)
+ eq_(row[1], "d1")
+
+ def test_via_col_object(self, connection):
+ row = connection.execute(
+ self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id)
+ ).first()
+
+ eq_(row._mapping[self.tables.plain_pk.c.id], 1)
+ eq_(row._mapping[self.tables.plain_pk.c.data], "d1")
+
+ @requirements.duplicate_names_in_cursor_description
+ def test_row_with_dupe_names(self, connection):
+ result = connection.execute(
+ select(
+ self.tables.plain_pk.c.data,
+ self.tables.plain_pk.c.data.label("data"),
+ ).order_by(self.tables.plain_pk.c.id)
+ )
+ row = result.first()
+ eq_(result.keys(), ["data", "data"])
+ eq_(row, ("d1", "d1"))
+
+ def test_row_w_scalar_select(self, connection):
+ """test that a scalar select as a column is returned as such
+ and that type conversion works OK.
+
+ (this is half a SQLAlchemy Core test and half to catch database
+ backends that may have unusual behavior with scalar selects.)
+
+ """
+ datetable = self.tables.has_dates
+ s = select(datetable.alias("x").c.today).scalar_subquery()
+ s2 = select(datetable.c.id, s.label("somelabel"))
+ row = connection.execute(s2).first()
+
+ eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0))
+
+
+class PercentSchemaNamesTest(fixtures.TablesTest):
+ """tests using percent signs, spaces in table and column names.
+
+ This didn't work for PostgreSQL / MySQL drivers for a long time
+ but is now supported.
+
+ """
+
+ __requires__ = ("percent_schema_names",)
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ cls.tables.percent_table = Table(
+ "percent%table",
+ metadata,
+ Column("percent%", Integer),
+ Column("spaces % more spaces", Integer),
+ )
+ cls.tables.lightweight_percent_table = sql.table(
+ "percent%table",
+ sql.column("percent%"),
+ sql.column("spaces % more spaces"),
+ )
+
+ def test_single_roundtrip(self, connection):
+ percent_table = self.tables.percent_table
+ for params in [
+ {"percent%": 5, "spaces % more spaces": 12},
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
+ ]:
+ connection.execute(percent_table.insert(), params)
+ self._assert_table(connection)
+
+ def test_executemany_roundtrip(self, connection):
+ percent_table = self.tables.percent_table
+ connection.execute(
+ percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12}
+ )
+ connection.execute(
+ percent_table.insert(),
+ [
+ {"percent%": 7, "spaces % more spaces": 11},
+ {"percent%": 9, "spaces % more spaces": 10},
+ {"percent%": 11, "spaces % more spaces": 9},
+ ],
+ )
+ self._assert_table(connection)
+
+ def _assert_table(self, conn):
+ percent_table = self.tables.percent_table
+ lightweight_percent_table = self.tables.lightweight_percent_table
+
+ for table in (
+ percent_table,
+ percent_table.alias(),
+ lightweight_percent_table,
+ lightweight_percent_table.alias(),
+ ):
+ eq_(
+ list(
+ conn.execute(table.select().order_by(table.c["percent%"]))
+ ),
+ [(5, 12), (7, 11), (9, 10), (11, 9)],
+ )
+
+ eq_(
+ list(
+ conn.execute(
+ table.select()
+ .where(table.c["spaces % more spaces"].in_([9, 10]))
+ .order_by(table.c["percent%"])
+ )
+ ),
+ [(9, 10), (11, 9)],
+ )
+
+ row = conn.execute(
+ table.select().order_by(table.c["percent%"])
+ ).first()
+ eq_(row._mapping["percent%"], 5)
+ eq_(row._mapping["spaces % more spaces"], 12)
+
+ eq_(row._mapping[table.c["percent%"]], 5)
+ eq_(row._mapping[table.c["spaces % more spaces"]], 12)
+
+ conn.execute(
+ percent_table.update().values(
+ {percent_table.c["spaces % more spaces"]: 15}
+ )
+ )
+
+ eq_(
+ list(
+ conn.execute(
+ percent_table.select().order_by(
+ percent_table.c["percent%"]
+ )
+ )
+ ),
+ [(5, 15), (7, 15), (9, 15), (11, 15)],
+ )
+
+
+class ServerSideCursorsTest(
+ fixtures.TestBase, testing.AssertsExecutionResults
+):
+
+ __requires__ = ("server_side_cursors",)
+
+ __backend__ = True
+
+ def _is_server_side(self, cursor):
+ # TODO: this is a huge issue as it prevents these tests from being
+ # usable by third party dialects.
+ if self.engine.dialect.driver == "psycopg2":
+ return bool(cursor.name)
+ elif self.engine.dialect.driver == "pymysql":
+ sscursor = __import__("pymysql.cursors").cursors.SSCursor
+ return isinstance(cursor, sscursor)
+ elif self.engine.dialect.driver in ("aiomysql", "asyncmy"):
+ return cursor.server_side
+ elif self.engine.dialect.driver == "mysqldb":
+ sscursor = __import__("MySQLdb.cursors").cursors.SSCursor
+ return isinstance(cursor, sscursor)
+ elif self.engine.dialect.driver == "mariadbconnector":
+ return not cursor.buffered
+ elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"):
+ return cursor.server_side
+ elif self.engine.dialect.driver == "pg8000":
+ return getattr(cursor, "server_side", False)
+ else:
+ return False
+
+ def _fixture(self, server_side_cursors):
+ if server_side_cursors:
+ with testing.expect_deprecated(
+ "The create_engine.server_side_cursors parameter is "
+ "deprecated and will be removed in a future release. "
+ "Please use the Connection.execution_options.stream_results "
+ "parameter."
+ ):
+ self.engine = engines.testing_engine(
+ options={"server_side_cursors": server_side_cursors}
+ )
+ else:
+ self.engine = engines.testing_engine(
+ options={"server_side_cursors": server_side_cursors}
+ )
+ return self.engine
+
+ @testing.combinations(
+ ("global_string", True, "select 1", True),
+ ("global_text", True, text("select 1"), True),
+ ("global_expr", True, select(1), True),
+ ("global_off_explicit", False, text("select 1"), False),
+ (
+ "stmt_option",
+ False,
+ select(1).execution_options(stream_results=True),
+ True,
+ ),
+ (
+ "stmt_option_disabled",
+ True,
+ select(1).execution_options(stream_results=False),
+ False,
+ ),
+ ("for_update_expr", True, select(1).with_for_update(), True),
+ # TODO: need a real requirement for this, or dont use this test
+ (
+ "for_update_string",
+ True,
+ "SELECT 1 FOR UPDATE",
+ True,
+ testing.skip_if("sqlite"),
+ ),
+ ("text_no_ss", False, text("select 42"), False),
+ (
+ "text_ss_option",
+ False,
+ text("select 42").execution_options(stream_results=True),
+ True,
+ ),
+ id_="iaaa",
+ argnames="engine_ss_arg, statement, cursor_ss_status",
+ )
+ def test_ss_cursor_status(
+ self, engine_ss_arg, statement, cursor_ss_status
+ ):
+ engine = self._fixture(engine_ss_arg)
+ with engine.begin() as conn:
+ if isinstance(statement, util.string_types):
+ result = conn.exec_driver_sql(statement)
+ else:
+ result = conn.execute(statement)
+ eq_(self._is_server_side(result.cursor), cursor_ss_status)
+ result.close()
+
+ def test_conn_option(self):
+ engine = self._fixture(False)
+
+ with engine.connect() as conn:
+ # should be enabled for this one
+ result = conn.execution_options(
+ stream_results=True
+ ).exec_driver_sql("select 1")
+ assert self._is_server_side(result.cursor)
+
+ def test_stmt_enabled_conn_option_disabled(self):
+ engine = self._fixture(False)
+
+ s = select(1).execution_options(stream_results=True)
+
+ with engine.connect() as conn:
+ # not this one
+ result = conn.execution_options(stream_results=False).execute(s)
+ assert not self._is_server_side(result.cursor)
+
+ def test_aliases_and_ss(self):
+ engine = self._fixture(False)
+ s1 = (
+ select(sql.literal_column("1").label("x"))
+ .execution_options(stream_results=True)
+ .subquery()
+ )
+
+ # options don't propagate out when subquery is used as a FROM clause
+ with engine.begin() as conn:
+ result = conn.execute(s1.select())
+ assert not self._is_server_side(result.cursor)
+ result.close()
+
+ s2 = select(1).select_from(s1)
+ with engine.begin() as conn:
+ result = conn.execute(s2)
+ assert not self._is_server_side(result.cursor)
+ result.close()
+
+ def test_roundtrip_fetchall(self, metadata):
+ md = self.metadata
+
+ engine = self._fixture(True)
+ test_table = Table(
+ "test_table",
+ md,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ with engine.begin() as connection:
+ test_table.create(connection, checkfirst=True)
+ connection.execute(test_table.insert(), dict(data="data1"))
+ connection.execute(test_table.insert(), dict(data="data2"))
+ eq_(
+ connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ ).fetchall(),
+ [(1, "data1"), (2, "data2")],
+ )
+ connection.execute(
+ test_table.update()
+ .where(test_table.c.id == 2)
+ .values(data=test_table.c.data + " updated")
+ )
+ eq_(
+ connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ ).fetchall(),
+ [(1, "data1"), (2, "data2 updated")],
+ )
+ connection.execute(test_table.delete())
+ eq_(
+ connection.scalar(
+ select(func.count("*")).select_from(test_table)
+ ),
+ 0,
+ )
+
+ def test_roundtrip_fetchmany(self, metadata):
+ md = self.metadata
+
+ engine = self._fixture(True)
+ test_table = Table(
+ "test_table",
+ md,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ with engine.begin() as connection:
+ test_table.create(connection, checkfirst=True)
+ connection.execute(
+ test_table.insert(),
+ [dict(data="data%d" % i) for i in range(1, 20)],
+ )
+
+ result = connection.execute(
+ test_table.select().order_by(test_table.c.id)
+ )
+
+ eq_(
+ result.fetchmany(5),
+ [(i, "data%d" % i) for i in range(1, 6)],
+ )
+ eq_(
+ result.fetchmany(10),
+ [(i, "data%d" % i) for i in range(6, 16)],
+ )
+ eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)])
diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py
new file mode 100644
index 0000000..82e831f
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_rowcount.py
@@ -0,0 +1,165 @@
+from sqlalchemy import bindparam
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import text
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+
+
+class RowCountTest(fixtures.TablesTest):
+ """test rowcount functionality"""
+
+ __requires__ = ("sane_rowcount",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "employees",
+ metadata,
+ Column(
+ "employee_id",
+ Integer,
+ autoincrement=False,
+ primary_key=True,
+ ),
+ Column("name", String(50)),
+ Column("department", String(1)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ cls.data = data = [
+ ("Angela", "A"),
+ ("Andrew", "A"),
+ ("Anand", "A"),
+ ("Bob", "B"),
+ ("Bobette", "B"),
+ ("Buffy", "B"),
+ ("Charlie", "C"),
+ ("Cynthia", "C"),
+ ("Chris", "C"),
+ ]
+
+ employees_table = cls.tables.employees
+ connection.execute(
+ employees_table.insert(),
+ [
+ {"employee_id": i, "name": n, "department": d}
+ for i, (n, d) in enumerate(data)
+ ],
+ )
+
+ def test_basic(self, connection):
+ employees_table = self.tables.employees
+ s = select(
+ employees_table.c.name, employees_table.c.department
+ ).order_by(employees_table.c.employee_id)
+ rows = connection.execute(s).fetchall()
+
+ eq_(rows, self.data)
+
+ def test_update_rowcount1(self, connection):
+ employees_table = self.tables.employees
+
+ # WHERE matches 3, 3 rows changed
+ department = employees_table.c.department
+ r = connection.execute(
+ employees_table.update().where(department == "C"),
+ {"department": "Z"},
+ )
+ assert r.rowcount == 3
+
+ def test_update_rowcount2(self, connection):
+ employees_table = self.tables.employees
+
+ # WHERE matches 3, 0 rows changed
+ department = employees_table.c.department
+
+ r = connection.execute(
+ employees_table.update().where(department == "C"),
+ {"department": "C"},
+ )
+ eq_(r.rowcount, 3)
+
+ @testing.requires.sane_rowcount_w_returning
+ def test_update_rowcount_return_defaults(self, connection):
+ employees_table = self.tables.employees
+
+ department = employees_table.c.department
+ stmt = (
+ employees_table.update()
+ .where(department == "C")
+ .values(name=employees_table.c.department + "Z")
+ .return_defaults()
+ )
+
+ r = connection.execute(stmt)
+ eq_(r.rowcount, 3)
+
+ def test_raw_sql_rowcount(self, connection):
+ # test issue #3622, make sure eager rowcount is called for text
+ result = connection.exec_driver_sql(
+ "update employees set department='Z' where department='C'"
+ )
+ eq_(result.rowcount, 3)
+
+ def test_text_rowcount(self, connection):
+ # test issue #3622, make sure eager rowcount is called for text
+ result = connection.execute(
+ text("update employees set department='Z' " "where department='C'")
+ )
+ eq_(result.rowcount, 3)
+
+ def test_delete_rowcount(self, connection):
+ employees_table = self.tables.employees
+
+ # WHERE matches 3, 3 rows deleted
+ department = employees_table.c.department
+ r = connection.execute(
+ employees_table.delete().where(department == "C")
+ )
+ eq_(r.rowcount, 3)
+
+ @testing.requires.sane_multi_rowcount
+ def test_multi_update_rowcount(self, connection):
+ employees_table = self.tables.employees
+ stmt = (
+ employees_table.update()
+ .where(employees_table.c.name == bindparam("emp_name"))
+ .values(department="C")
+ )
+
+ r = connection.execute(
+ stmt,
+ [
+ {"emp_name": "Bob"},
+ {"emp_name": "Cynthia"},
+ {"emp_name": "nonexistent"},
+ ],
+ )
+
+ eq_(r.rowcount, 2)
+
+ @testing.requires.sane_multi_rowcount
+ def test_multi_delete_rowcount(self, connection):
+ employees_table = self.tables.employees
+
+ stmt = employees_table.delete().where(
+ employees_table.c.name == bindparam("emp_name")
+ )
+
+ r = connection.execute(
+ stmt,
+ [
+ {"emp_name": "Bob"},
+ {"emp_name": "Cynthia"},
+ {"emp_name": "nonexistent"},
+ ],
+ )
+
+ eq_(r.rowcount, 2)
diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py
new file mode 100644
index 0000000..cb78fff
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_select.py
@@ -0,0 +1,1783 @@
+import itertools
+
+from .. import AssertsCompiledSQL
+from .. import AssertsExecutionResults
+from .. import config
+from .. import fixtures
+from ..assertions import assert_raises
+from ..assertions import eq_
+from ..assertions import in_
+from ..assertsql import CursorSQL
+from ..schema import Column
+from ..schema import Table
+from ... import bindparam
+from ... import case
+from ... import column
+from ... import Computed
+from ... import exists
+from ... import false
+from ... import ForeignKey
+from ... import func
+from ... import Identity
+from ... import Integer
+from ... import literal
+from ... import literal_column
+from ... import null
+from ... import select
+from ... import String
+from ... import table
+from ... import testing
+from ... import text
+from ... import true
+from ... import tuple_
+from ... import TupleType
+from ... import union
+from ... import util
+from ... import values
+from ...exc import DatabaseError
+from ...exc import ProgrammingError
+from ...util import collections_abc
+
+
+class CollateTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(100)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "collate data1"},
+ {"id": 2, "data": "collate data2"},
+ ],
+ )
+
+ def _assert_result(self, select, result):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select).fetchall(), result)
+
+ @testing.requires.order_by_collation
+ def test_collate_order_by(self):
+ collation = testing.requires.get_order_by_collation(testing.config)
+
+ self._assert_result(
+ select(self.tables.some_table).order_by(
+ self.tables.some_table.c.data.collate(collation).asc()
+ ),
+ [(1, "collate data1"), (2, "collate data2")],
+ )
+
+
+class OrderByLabelTest(fixtures.TablesTest):
+ """Test the dialect sends appropriate ORDER BY expressions when
+ labels are used.
+
+ This essentially exercises the "supports_simple_order_by_label"
+ setting.
+
+ """
+
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("q", String(50)),
+ Column("p", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"},
+ {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"},
+ {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"},
+ ],
+ )
+
+ def _assert_result(self, select, result):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select).fetchall(), result)
+
+ def test_plain(self):
+ table = self.tables.some_table
+ lx = table.c.x.label("lx")
+ self._assert_result(select(lx).order_by(lx), [(1,), (2,), (3,)])
+
+ def test_composed_int(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ self._assert_result(select(lx).order_by(lx), [(3,), (5,), (7,)])
+
+ def test_composed_multiple(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ ly = (func.lower(table.c.q) + table.c.p).label("ly")
+ self._assert_result(
+ select(lx, ly).order_by(lx, ly.desc()),
+ [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))],
+ )
+
+ def test_plain_desc(self):
+ table = self.tables.some_table
+ lx = table.c.x.label("lx")
+ self._assert_result(select(lx).order_by(lx.desc()), [(3,), (2,), (1,)])
+
+ def test_composed_int_desc(self):
+ table = self.tables.some_table
+ lx = (table.c.x + table.c.y).label("lx")
+ self._assert_result(select(lx).order_by(lx.desc()), [(7,), (5,), (3,)])
+
+ @testing.requires.group_by_complex_expression
+ def test_group_by_composed(self):
+ table = self.tables.some_table
+ expr = (table.c.x + table.c.y).label("lx")
+ stmt = (
+ select(func.count(table.c.id), expr).group_by(expr).order_by(expr)
+ )
+ self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)])
+
+
+class ValuesExpressionTest(fixtures.TestBase):
+ __requires__ = ("table_value_constructor",)
+
+ __backend__ = True
+
+ def test_tuples(self, connection):
+ value_expr = values(
+ column("id", Integer), column("name", String), name="my_values"
+ ).data([(1, "name1"), (2, "name2"), (3, "name3")])
+
+ eq_(
+ connection.execute(select(value_expr)).all(),
+ [(1, "name1"), (2, "name2"), (3, "name3")],
+ )
+
+
+class FetchLimitOffsetTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ {"id": 5, "x": 4, "y": 6},
+ ],
+ )
+
+ def _assert_result(
+ self, connection, select, result, params=(), set_=False
+ ):
+ if set_:
+ query_res = connection.execute(select, params).fetchall()
+ eq_(len(query_res), len(result))
+ eq_(set(query_res), set(result))
+
+ else:
+ eq_(connection.execute(select, params).fetchall(), result)
+
+ def _assert_result_str(self, select, result, params=()):
+ conn = config.db.connect(close_with_result=True)
+ eq_(conn.exec_driver_sql(select, params).fetchall(), result)
+
+ def test_simple_limit(self, connection):
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id)
+ self._assert_result(
+ connection,
+ stmt.limit(2),
+ [(1, 1, 2), (2, 2, 3)],
+ )
+ self._assert_result(
+ connection,
+ stmt.limit(3),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ def test_limit_render_multiple_times(self, connection):
+ table = self.tables.some_table
+ stmt = select(table.c.id).limit(1).scalar_subquery()
+
+ u = union(select(stmt), select(stmt)).subquery().select()
+
+ self._assert_result(
+ connection,
+ u,
+ [
+ (1,),
+ ],
+ )
+
+ @testing.requires.fetch_first
+ def test_simple_fetch(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(2),
+ [(1, 1, 2), (2, 2, 3)],
+ )
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(3),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.offset
+ def test_simple_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(2),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(3),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.combinations(
+ ([(2, 0), (2, 1), (3, 2)]),
+ ([(2, 1), (2, 0), (3, 2)]),
+ ([(3, 1), (2, 1), (3, 1)]),
+ argnames="cases",
+ )
+ @testing.requires.offset
+ def test_simple_limit_offset(self, connection, cases):
+ table = self.tables.some_table
+ connection = connection.execution_options(compiled_cache={})
+
+ assert_data = [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)]
+
+ for limit, offset in cases:
+ expected = assert_data[offset : offset + limit]
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(limit).offset(offset),
+ expected,
+ )
+
+ @testing.requires.fetch_first
+ def test_simple_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(2).offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(3).offset(2),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.fetch_no_order_by
+ def test_fetch_offset_no_order(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).fetch(10),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ @testing.requires.offset
+ def test_simple_offset_zero(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(0),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(1),
+ [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.offset
+ def test_limit_offset_nobinds(self):
+ """test that 'literal binds' mode works - no bound params."""
+
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id).limit(2).offset(1)
+ sql = stmt.compile(
+ dialect=config.db.dialect, compile_kwargs={"literal_binds": True}
+ )
+ sql = str(sql)
+
+ self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)])
+
+ @testing.requires.fetch_first
+ def test_fetch_offset_nobinds(self):
+ """test that 'literal binds' mode works - no bound params."""
+
+ table = self.tables.some_table
+ stmt = select(table).order_by(table.c.id).fetch(2).offset(1)
+ sql = stmt.compile(
+ dialect=config.db.dialect, compile_kwargs={"literal_binds": True}
+ )
+ sql = str(sql)
+
+ self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)])
+
+ @testing.requires.bound_limit_offset
+ def test_bound_limit(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(bindparam("l")),
+ [(1, 1, 2), (2, 2, 3)],
+ params={"l": 2},
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).limit(bindparam("l")),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ params={"l": 3},
+ )
+
+ @testing.requires.bound_limit_offset
+ def test_bound_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"o": 2},
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"o": 1},
+ )
+
+ @testing.requires.bound_limit_offset
+ def test_bound_limit_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(bindparam("l"))
+ .offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4)],
+ params={"l": 2, "o": 1},
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(bindparam("l"))
+ .offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"l": 3, "o": 2},
+ )
+
+ @testing.requires.fetch_first
+ def test_bound_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(bindparam("f"))
+ .offset(bindparam("o")),
+ [(2, 2, 3), (3, 3, 4)],
+ params={"f": 2, "o": 1},
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(bindparam("f"))
+ .offset(bindparam("o")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ params={"f": 3, "o": 2},
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .offset(literal_column("1") + literal_column("2")),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("2")),
+ [(1, 1, 2), (2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ @testing.requires.fetch_first
+ @testing.requires.fetch_expression
+ def test_expr_fetch_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(literal_column("1") + literal_column("1"))
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_simple_limit_expr_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(2)
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(3)
+ .offset(literal_column("1") + literal_column("1")),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.sql_expression_limit_offset
+ def test_expr_limit_simple_offset(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(2),
+ [(3, 3, 4), (4, 4, 5)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .limit(literal_column("1") + literal_column("1"))
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.fetch_ties
+ def test_simple_fetch_ties(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.x.desc()).fetch(1, with_ties=True),
+ [(4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.x.desc()).fetch(3, with_ties=True),
+ [(3, 3, 4), (4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_ties(self, connection):
+ table = self.tables.some_table
+ fa = connection.execute(
+ select(table)
+ .order_by(table.c.x)
+ .fetch(2, with_ties=True)
+ .offset(2)
+ ).fetchall()
+ eq_(fa[0], (3, 3, 4))
+ eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)]))
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_ties_exact_number(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x)
+ .fetch(2, with_ties=True)
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x)
+ .fetch(3, with_ties=True)
+ .offset(3),
+ [(4, 4, 5), (5, 4, 6)],
+ )
+
+ @testing.requires.fetch_percent
+ def test_simple_fetch_percent(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table).order_by(table.c.id).fetch(20, percent=True),
+ [(1, 1, 2)],
+ )
+
+ @testing.requires.fetch_percent
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_percent(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.id)
+ .fetch(40, percent=True)
+ .offset(1),
+ [(2, 2, 3), (3, 3, 4)],
+ )
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_percent
+ def test_simple_fetch_percent_ties(self, connection):
+ table = self.tables.some_table
+ self._assert_result(
+ connection,
+ select(table)
+ .order_by(table.c.x.desc())
+ .fetch(20, percent=True, with_ties=True),
+ [(4, 4, 5), (5, 4, 6)],
+ set_=True,
+ )
+
+ @testing.requires.fetch_ties
+ @testing.requires.fetch_percent
+ @testing.requires.fetch_offset_with_options
+ def test_fetch_offset_percent_ties(self, connection):
+ table = self.tables.some_table
+ fa = connection.execute(
+ select(table)
+ .order_by(table.c.x)
+ .fetch(40, percent=True, with_ties=True)
+ .offset(2)
+ ).fetchall()
+ eq_(fa[0], (3, 3, 4))
+ eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)]))
+
+
+class JoinTest(fixtures.TablesTest):
+ __backend__ = True
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("a", metadata, Column("id", Integer, primary_key=True))
+ Table(
+ "b",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("a_id", ForeignKey("a.id"), nullable=False),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.a.insert(),
+ [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}],
+ )
+
+ connection.execute(
+ cls.tables.b.insert(),
+ [
+ {"id": 1, "a_id": 1},
+ {"id": 2, "a_id": 1},
+ {"id": 4, "a_id": 2},
+ {"id": 5, "a_id": 3},
+ ],
+ )
+
+ def test_inner_join_fk(self):
+ a, b = self.tables("a", "b")
+
+ stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id)
+
+ self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)])
+
+ def test_inner_join_true(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.join(b, true()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(
+ stmt,
+ [
+ (a, b, c)
+ for (a,), (b, c) in itertools.product(
+ [(1,), (2,), (3,), (4,), (5,)],
+ [(1, 1), (2, 1), (4, 2), (5, 3)],
+ )
+ ],
+ )
+
+ def test_inner_join_false(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.join(b, false()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(stmt, [])
+
+ def test_outer_join_false(self):
+ a, b = self.tables("a", "b")
+
+ stmt = (
+ select(a, b)
+ .select_from(a.outerjoin(b, false()))
+ .order_by(a.c.id, b.c.id)
+ )
+
+ self._assert_result(
+ stmt,
+ [
+ (1, None, None),
+ (2, None, None),
+ (3, None, None),
+ (4, None, None),
+ (5, None, None),
+ ],
+ )
+
+ def test_outer_join_fk(self):
+ a, b = self.tables("a", "b")
+
+ stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id)
+
+ self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)])
+
+
+class CompoundSelectTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2},
+ {"id": 2, "x": 2, "y": 3},
+ {"id": 3, "x": 3, "y": 4},
+ {"id": 4, "x": 4, "y": 5},
+ ],
+ )
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ def test_plain_union(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_select_from_plain_union(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2)
+ s2 = select(table).where(table.c.id == 3)
+
+ u1 = union(s1, s2).alias().select()
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.order_by_col_from_union
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_wo_limit_offset
+ def test_order_by_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).order_by(table.c.id)
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_distinct_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).distinct()
+ s2 = select(table).where(table.c.id == 3).distinct()
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ @testing.requires.parens_in_union_contained_select_w_limit_offset
+ def test_limit_offset_in_unions_from_alias(self):
+ table = self.tables.some_table
+ s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id)
+ s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id)
+
+ # this necessarily has double parens
+ u1 = union(s1, s2).alias()
+ self._assert_result(
+ u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+ def test_limit_offset_aliased_selectable_in_unions(self):
+ table = self.tables.some_table
+ s1 = (
+ select(table)
+ .where(table.c.id == 2)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+ s2 = (
+ select(table)
+ .where(table.c.id == 3)
+ .limit(1)
+ .order_by(table.c.id)
+ .alias()
+ .select()
+ )
+
+ u1 = union(s1, s2).limit(2)
+ self._assert_result(
+ u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)]
+ )
+
+
+class PostCompileParamsTest(
+ AssertsExecutionResults, AssertsCompiledSQL, fixtures.TablesTest
+):
+ __backend__ = True
+
+ __requires__ = ("standard_cursor_sql",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "z": "z1"},
+ {"id": 2, "x": 2, "y": 3, "z": "z2"},
+ {"id": 3, "x": 3, "y": 4, "z": "z3"},
+ {"id": 4, "x": 4, "y": 5, "z": "z4"},
+ ],
+ )
+
+ def test_compile(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", literal_execute=True)
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT some_table.id FROM some_table "
+ "WHERE some_table.x = __[POSTCOMPILE_q]",
+ {},
+ )
+
+ def test_compile_literal_binds(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", 10, literal_execute=True)
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT some_table.id FROM some_table WHERE some_table.x = 10",
+ {},
+ literal_binds=True,
+ )
+
+ def test_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x == bindparam("q", literal_execute=True)
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=10))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE some_table.x = 10",
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ def test_execute_expanding_plus_literal_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ table.c.x.in_(bindparam("q", expanding=True, literal_execute=True))
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[5, 6, 7]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE some_table.x IN (5, 6, 7)",
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ @testing.requires.tuple_in
+ def test_execute_tuple_expanding_plus_literal_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ tuple_(table.c.x, table.c.y).in_(
+ bindparam("q", expanding=True, literal_execute=True)
+ )
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[(5, 10), (12, 18)]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE (some_table.x, some_table.y) "
+ "IN (%s(5, 10), (12, 18))"
+ % ("VALUES " if config.db.dialect.tuple_in_values else ""),
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+ @testing.requires.tuple_in
+ def test_execute_tuple_expanding_plus_literal_heterogeneous_execute(self):
+ table = self.tables.some_table
+
+ stmt = select(table.c.id).where(
+ tuple_(table.c.x, table.c.z).in_(
+ bindparam("q", expanding=True, literal_execute=True)
+ )
+ )
+
+ with self.sql_execution_asserter() as asserter:
+ with config.db.connect() as conn:
+ conn.execute(stmt, dict(q=[(5, "z1"), (12, "z3")]))
+
+ asserter.assert_(
+ CursorSQL(
+ "SELECT some_table.id \nFROM some_table "
+ "\nWHERE (some_table.x, some_table.z) "
+ "IN (%s(5, 'z1'), (12, 'z3'))"
+ % ("VALUES " if config.db.dialect.tuple_in_values else ""),
+ () if config.db.dialect.positional else {},
+ )
+ )
+
+
+class ExpandingBoundInTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", Integer),
+ Column("y", Integer),
+ Column("z", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "x": 1, "y": 2, "z": "z1"},
+ {"id": 2, "x": 2, "y": 3, "z": "z2"},
+ {"id": 3, "x": 3, "y": 4, "z": "z3"},
+ {"id": 4, "x": 4, "y": 5, "z": "z4"},
+ ],
+ )
+
+ def _assert_result(self, select, result, params=()):
+ with config.db.connect() as conn:
+ eq_(conn.execute(select, params).fetchall(), result)
+
+ def test_multiple_empty_sets_bindparam(self):
+ # test that any anonymous aliasing used by the dialect
+ # is fine with duplicates
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_(bindparam("q")))
+ .where(table.c.y.in_(bindparam("p")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": [], "p": []})
+
+ def test_multiple_empty_sets_direct(self):
+ # test that any anonymous aliasing used by the dialect
+ # is fine with duplicates
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([]))
+ .where(table.c.y.in_([]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [])
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_heterogeneous_tuples_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_heterogeneous_tuples_direct(self):
+ table = self.tables.some_table
+
+ def go(val, expected):
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(val))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, expected)
+
+ go([], [])
+ go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)])
+ go([], [])
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_homogeneous_tuples_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ @testing.requires.tuple_in_w_empty
+ def test_empty_homogeneous_tuples_direct(self):
+ table = self.tables.some_table
+
+ def go(val, expected):
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(val))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, expected)
+
+ go([], [])
+ go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)])
+ go([], [])
+
+ def test_bound_in_scalar_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]})
+
+ def test_bound_in_scalar_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([2, 3, 4]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)])
+
+ def test_nonempty_in_plus_empty_notin(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([2, 3]))
+ .where(table.c.id.not_in([]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,)])
+
+ def test_empty_in_plus_notempty_notin(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_([]))
+ .where(table.c.id.not_in([2, 3]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [])
+
+ def test_typed_str_in(self):
+ """test related to #7292.
+
+ as a type is given to the bound param, there is no ambiguity
+ to the type of element.
+
+ """
+
+ stmt = text(
+ "select id FROM some_table WHERE z IN :q ORDER BY id"
+ ).bindparams(bindparam("q", type_=String, expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": ["z2", "z3", "z4"]},
+ )
+
+ def test_untyped_str_in(self):
+ """test related to #7292.
+
+ for untyped expression, we look at the types of elements.
+ Test for Sequence to detect tuple in. but not strings or bytes!
+ as always....
+
+ """
+
+ stmt = text(
+ "select id FROM some_table WHERE z IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": ["z2", "z3", "z4"]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_two_tuple_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]}
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_two_tuple_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)]))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(2,), (3,), (4,)])
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(
+ tuple_(table.c.x, table.c.z).in_(
+ [(2, "z2"), (3, "z3"), (4, "z4")]
+ )
+ )
+ .order_by(table.c.id)
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_text_bindparam(self):
+ # note this becomes ARRAY if we dont use expanding
+ # explicitly right now
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self):
+ class LikeATuple(collections_abc.Sequence):
+ def __init__(self, *data):
+ self._data = data
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def __getitem__(self, idx):
+ return self._data[idx]
+
+ def __len__(self):
+ return len(self._data)
+
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(
+ bindparam(
+ "q", type_=TupleType(Integer(), String()), expanding=True
+ )
+ )
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={
+ "q": [
+ LikeATuple(2, "z2"),
+ LikeATuple(3, "z3"),
+ LikeATuple(4, "z4"),
+ ]
+ },
+ )
+
+ @testing.requires.tuple_in
+ def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self):
+ # note this becomes ARRAY if we dont use expanding
+ # explicitly right now
+
+ class LikeATuple(collections_abc.Sequence):
+ def __init__(self, *data):
+ self._data = data
+
+ def __iter__(self):
+ return iter(self._data)
+
+ def __getitem__(self, idx):
+ return self._data[idx]
+
+ def __len__(self):
+ return len(self._data)
+
+ stmt = text(
+ "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+ ).bindparams(bindparam("q", expanding=True))
+ self._assert_result(
+ stmt,
+ [(2,), (3,), (4,)],
+ params={
+ "q": [
+ LikeATuple(2, "z2"),
+ LikeATuple(3, "z3"),
+ LikeATuple(4, "z4"),
+ ]
+ },
+ )
+
+ def test_empty_set_against_integer_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ def test_empty_set_against_integer_direct(self):
+ table = self.tables.some_table
+ stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id)
+ self._assert_result(stmt, [])
+
+ def test_empty_set_against_integer_negation_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.x.not_in(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
+ def test_empty_set_against_integer_negation_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
+
+ def test_empty_set_against_string_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.z.in_(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [], params={"q": []})
+
+ def test_empty_set_against_string_direct(self):
+ table = self.tables.some_table
+ stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id)
+ self._assert_result(stmt, [])
+
+ def test_empty_set_against_string_negation_bindparam(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id)
+ .where(table.c.z.not_in(bindparam("q")))
+ .order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
+
+ def test_empty_set_against_string_negation_direct(self):
+ table = self.tables.some_table
+ stmt = (
+ select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id)
+ )
+ self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
+
+ def test_null_in_empty_set_is_false_bindparam(self, connection):
+ stmt = select(
+ case(
+ (
+ null().in_(bindparam("foo", value=())),
+ true(),
+ ),
+ else_=false(),
+ )
+ )
+ in_(connection.execute(stmt).fetchone()[0], (False, 0))
+
+ def test_null_in_empty_set_is_false_direct(self, connection):
+ stmt = select(
+ case(
+ (
+ null().in_([]),
+ true(),
+ ),
+ else_=false(),
+ )
+ )
+ in_(connection.execute(stmt).fetchone()[0], (False, 0))
+
+
+class LikeFunctionsTest(fixtures.TablesTest):
+ __backend__ = True
+
+ run_inserts = "once"
+ run_deletes = None
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "some_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "abcdefg"},
+ {"id": 2, "data": "ab/cdefg"},
+ {"id": 3, "data": "ab%cdefg"},
+ {"id": 4, "data": "ab_cdefg"},
+ {"id": 5, "data": "abcde/fg"},
+ {"id": 6, "data": "abcde%fg"},
+ {"id": 7, "data": "ab#cdefg"},
+ {"id": 8, "data": "ab9cdefg"},
+ {"id": 9, "data": "abcde#fg"},
+ {"id": 10, "data": "abcd9fg"},
+ {"id": 11, "data": None},
+ ],
+ )
+
+ def _test(self, expr, expected):
+ some_table = self.tables.some_table
+
+ with config.db.connect() as conn:
+ rows = {
+ value
+ for value, in conn.execute(select(some_table.c.id).where(expr))
+ }
+
+ eq_(rows, expected)
+
+ def test_startswith_unescaped(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab%c"), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
+
+ def test_startswith_autoescape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab%c", autoescape=True), {3})
+
+ def test_startswith_sqlexpr(self):
+ col = self.tables.some_table.c.data
+ self._test(
+ col.startswith(literal_column("'ab%c'")),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
+ )
+
+ def test_startswith_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab##c", escape="#"), {7})
+
+ def test_startswith_autoescape_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.startswith("ab%c", autoescape=True, escape="#"), {3})
+ self._test(col.startswith("ab#c", autoescape=True, escape="#"), {7})
+
+ def test_endswith_unescaped(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e%fg"), {1, 2, 3, 4, 5, 6, 7, 8, 9})
+
+ def test_endswith_sqlexpr(self):
+ col = self.tables.some_table.c.data
+ self._test(
+ col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9}
+ )
+
+ def test_endswith_autoescape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e%fg", autoescape=True), {6})
+
+ def test_endswith_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e##fg", escape="#"), {9})
+
+ def test_endswith_autoescape_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.endswith("e%fg", autoescape=True, escape="#"), {6})
+ self._test(col.endswith("e#fg", autoescape=True, escape="#"), {9})
+
+ def test_contains_unescaped(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b%cde"), {1, 2, 3, 4, 5, 6, 7, 8, 9})
+
+ def test_contains_autoescape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b%cde", autoescape=True), {3})
+
+ def test_contains_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b##cde", escape="#"), {7})
+
+ def test_contains_autoescape_escape(self):
+ col = self.tables.some_table.c.data
+ self._test(col.contains("b%cd", autoescape=True, escape="#"), {3})
+ self._test(col.contains("b#cd", autoescape=True, escape="#"), {7})
+
+ @testing.requires.regexp_match
+ def test_not_regexp_match(self):
+ col = self.tables.some_table.c.data
+ self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10})
+
+ @testing.requires.regexp_replace
+ def test_regexp_replace(self):
+ col = self.tables.some_table.c.data
+ self._test(
+ col.regexp_replace("a.cde", "FOO").contains("FOO"), {1, 5, 6, 9}
+ )
+
+ @testing.requires.regexp_match
+ @testing.combinations(
+ ("a.cde", {1, 5, 6, 9}),
+ ("abc", {1, 5, 6, 9, 10}),
+ ("^abc", {1, 5, 6, 9, 10}),
+ ("9cde", {8}),
+ ("^a", set(range(1, 11))),
+ ("(b|c)", set(range(1, 11))),
+ ("^(b|c)", set()),
+ )
+ def test_regexp_match(self, text, expected):
+ col = self.tables.some_table.c.data
+ self._test(col.regexp_match(text), expected)
+
+
+class ComputedColumnTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("computed_columns",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "square",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("side", Integer),
+ Column("area", Integer, Computed("side * side")),
+ Column("perimeter", Integer, Computed("4 * side")),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.square.insert(),
+ [{"id": 1, "side": 10}, {"id": 10, "side": 42}],
+ )
+
+ def test_select_all(self):
+ with config.db.connect() as conn:
+ res = conn.execute(
+ select(text("*"))
+ .select_from(self.tables.square)
+ .order_by(self.tables.square.c.id)
+ ).fetchall()
+ eq_(res, [(1, 10, 100, 40), (10, 42, 1764, 168)])
+
+ def test_select_columns(self):
+ with config.db.connect() as conn:
+ res = conn.execute(
+ select(
+ self.tables.square.c.area, self.tables.square.c.perimeter
+ )
+ .select_from(self.tables.square)
+ .order_by(self.tables.square.c.id)
+ ).fetchall()
+ eq_(res, [(100, 40), (1764, 168)])
+
+
+class IdentityColumnTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("identity_columns",)
+ run_inserts = "once"
+ run_deletes = "once"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "tbl_a",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Identity(
+ always=True, start=42, nominvalue=True, nomaxvalue=True
+ ),
+ primary_key=True,
+ ),
+ Column("desc", String(100)),
+ )
+ Table(
+ "tbl_b",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Identity(increment=-5, start=0, minvalue=-1000, maxvalue=0),
+ primary_key=True,
+ ),
+ Column("desc", String(100)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.tbl_a.insert(),
+ [{"desc": "a"}, {"desc": "b"}],
+ )
+ connection.execute(
+ cls.tables.tbl_b.insert(),
+ [{"desc": "a"}, {"desc": "b"}],
+ )
+ connection.execute(
+ cls.tables.tbl_b.insert(),
+ [{"id": 42, "desc": "c"}],
+ )
+
+ def test_select_all(self, connection):
+ res = connection.execute(
+ select(text("*"))
+ .select_from(self.tables.tbl_a)
+ .order_by(self.tables.tbl_a.c.id)
+ ).fetchall()
+ eq_(res, [(42, "a"), (43, "b")])
+
+ res = connection.execute(
+ select(text("*"))
+ .select_from(self.tables.tbl_b)
+ .order_by(self.tables.tbl_b.c.id)
+ ).fetchall()
+ eq_(res, [(-5, "b"), (0, "a"), (42, "c")])
+
+ def test_select_columns(self, connection):
+
+ res = connection.execute(
+ select(self.tables.tbl_a.c.id).order_by(self.tables.tbl_a.c.id)
+ ).fetchall()
+ eq_(res, [(42,), (43,)])
+
+ @testing.requires.identity_columns_standard
+ def test_insert_always_error(self, connection):
+ def fn():
+ connection.execute(
+ self.tables.tbl_a.insert(),
+ [{"id": 200, "desc": "a"}],
+ )
+
+ assert_raises((DatabaseError, ProgrammingError), fn)
+
+
+class IdentityAutoincrementTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("autoincrement_without_sequence",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "tbl",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Identity(),
+ primary_key=True,
+ autoincrement=True,
+ ),
+ Column("desc", String(100)),
+ )
+
+ def test_autoincrement_with_identity(self, connection):
+ res = connection.execute(self.tables.tbl.insert(), {"desc": "row"})
+ res = connection.execute(self.tables.tbl.select()).first()
+ eq_(res, (1, "row"))
+
+
+class ExistsTest(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "stuff",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.stuff.insert(),
+ [
+ {"id": 1, "data": "some data"},
+ {"id": 2, "data": "some data"},
+ {"id": 3, "data": "some data"},
+ {"id": 4, "data": "some other data"},
+ ],
+ )
+
+ def test_select_exists(self, connection):
+ stuff = self.tables.stuff
+ eq_(
+ connection.execute(
+ select(literal(1)).where(
+ exists().where(stuff.c.data == "some data")
+ )
+ ).fetchall(),
+ [(1,)],
+ )
+
+ def test_select_exists_false(self, connection):
+ stuff = self.tables.stuff
+ eq_(
+ connection.execute(
+ select(literal(1)).where(
+ exists().where(stuff.c.data == "no data")
+ )
+ ).fetchall(),
+ [],
+ )
+
+
+class DistinctOnTest(AssertsCompiledSQL, fixtures.TablesTest):
+ __backend__ = True
+
+ @testing.fails_if(testing.requires.supports_distinct_on)
+ def test_distinct_on(self):
+ stm = select("*").distinct(column("q")).select_from(table("foo"))
+ with testing.expect_deprecated(
+ "DISTINCT ON is currently supported only by the PostgreSQL "
+ ):
+ self.assert_compile(stm, "SELECT DISTINCT * FROM foo")
+
+
+class IsOrIsNotDistinctFromTest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = ("supports_is_distinct_from",)
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "is_distinct_test",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("col_a", Integer, nullable=True),
+ Column("col_b", Integer, nullable=True),
+ )
+
+ @testing.combinations(
+ ("both_int_different", 0, 1, 1),
+ ("both_int_same", 1, 1, 0),
+ ("one_null_first", None, 1, 1),
+ ("one_null_second", 0, None, 1),
+ ("both_null", None, None, 0),
+ id_="iaaa",
+ argnames="col_a_value, col_b_value, expected_row_count_for_is",
+ )
+ def test_is_or_is_not_distinct_from(
+ self, col_a_value, col_b_value, expected_row_count_for_is, connection
+ ):
+ tbl = self.tables.is_distinct_test
+
+ connection.execute(
+ tbl.insert(),
+ [{"id": 1, "col_a": col_a_value, "col_b": col_b_value}],
+ )
+
+ result = connection.execute(
+ tbl.select().where(tbl.c.col_a.is_distinct_from(tbl.c.col_b))
+ ).fetchall()
+ eq_(
+ len(result),
+ expected_row_count_for_is,
+ )
+
+ expected_row_count_for_is_not = (
+ 1 if expected_row_count_for_is == 0 else 0
+ )
+ result = connection.execute(
+ tbl.select().where(tbl.c.col_a.is_not_distinct_from(tbl.c.col_b))
+ ).fetchall()
+ eq_(
+ len(result),
+ expected_row_count_for_is_not,
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
new file mode 100644
index 0000000..d6747d2
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -0,0 +1,282 @@
+from .. import config
+from .. import fixtures
+from ..assertions import eq_
+from ..assertions import is_true
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import inspect
+from ... import Integer
+from ... import MetaData
+from ... import Sequence
+from ... import String
+from ... import testing
+
+
+class SequenceTest(fixtures.TablesTest):
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ run_create_tables = "each"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "seq_pk",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("tab_id_seq"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ )
+
+ Table(
+ "seq_opt_pk",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("tab_id_seq", data_type=Integer, optional=True),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ )
+
+ Table(
+ "seq_no_returning",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_id_seq"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=False,
+ )
+
+ if testing.requires.schemas.enabled:
+ Table(
+ "seq_no_returning_sch",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_sch_id_seq", schema=config.test_schema),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=False,
+ schema=config.test_schema,
+ )
+
+ def test_insert_roundtrip(self, connection):
+ connection.execute(self.tables.seq_pk.insert(), dict(data="some data"))
+ self._assert_round_trip(self.tables.seq_pk, connection)
+
+ def test_insert_lastrowid(self, connection):
+ r = connection.execute(
+ self.tables.seq_pk.insert(), dict(data="some data")
+ )
+ eq_(
+ r.inserted_primary_key, (testing.db.dialect.default_sequence_base,)
+ )
+
+ def test_nextval_direct(self, connection):
+ r = connection.execute(self.tables.seq_pk.c.id.default)
+ eq_(r, testing.db.dialect.default_sequence_base)
+
+ @requirements.sequences_optional
+ def test_optional_seq(self, connection):
+ r = connection.execute(
+ self.tables.seq_opt_pk.insert(), dict(data="some data")
+ )
+ eq_(r.inserted_primary_key, (1,))
+
+ def _assert_round_trip(self, table, conn):
+ row = conn.execute(table.select()).first()
+ eq_(row, (testing.db.dialect.default_sequence_base, "some data"))
+
+ def test_insert_roundtrip_no_implicit_returning(self, connection):
+ connection.execute(
+ self.tables.seq_no_returning.insert(), dict(data="some data")
+ )
+ self._assert_round_trip(self.tables.seq_no_returning, connection)
+
+ @testing.combinations((True,), (False,), argnames="implicit_returning")
+ @testing.requires.schemas
+ def test_insert_roundtrip_translate(self, connection, implicit_returning):
+
+ seq_no_returning = Table(
+ "seq_no_returning_sch",
+ MetaData(),
+ Column(
+ "id",
+ Integer,
+ Sequence("noret_sch_id_seq", schema="alt_schema"),
+ primary_key=True,
+ ),
+ Column("data", String(50)),
+ implicit_returning=implicit_returning,
+ schema="alt_schema",
+ )
+
+ connection = connection.execution_options(
+ schema_translate_map={"alt_schema": config.test_schema}
+ )
+ connection.execute(seq_no_returning.insert(), dict(data="some data"))
+ self._assert_round_trip(seq_no_returning, connection)
+
+ @testing.requires.schemas
+ def test_nextval_direct_schema_translate(self, connection):
+ seq = Sequence("noret_sch_id_seq", schema="alt_schema")
+ connection = connection.execution_options(
+ schema_translate_map={"alt_schema": config.test_schema}
+ )
+
+ r = connection.execute(seq)
+ eq_(r, testing.db.dialect.default_sequence_base)
+
+
+class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ def test_literal_binds_inline_compile(self, connection):
+ table = Table(
+ "x",
+ MetaData(),
+ Column("y", Integer, Sequence("y_seq")),
+ Column("q", Integer),
+ )
+
+ stmt = table.insert().values(q=5)
+
+ seq_nextval = connection.dialect.statement_compiler(
+ statement=None, dialect=connection.dialect
+ ).visit_sequence(Sequence("y_seq"))
+ self.assert_compile(
+ stmt,
+ "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,),
+ literal_binds=True,
+ dialect=connection.dialect,
+ )
+
+
+class HasSequenceTest(fixtures.TablesTest):
+ run_deletes = None
+
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Sequence("user_id_seq", metadata=metadata)
+ Sequence(
+ "other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True
+ )
+ if testing.requires.schemas.enabled:
+ Sequence(
+ "user_id_seq", schema=config.test_schema, metadata=metadata
+ )
+ Sequence(
+ "schema_seq", schema=config.test_schema, metadata=metadata
+ )
+ Table(
+ "user_id_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ )
+
+ def test_has_sequence(self, connection):
+ eq_(
+ inspect(connection).has_sequence("user_id_seq"),
+ True,
+ )
+
+ def test_has_sequence_other_object(self, connection):
+ eq_(
+ inspect(connection).has_sequence("user_id_table"),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_schema(self, connection):
+ eq_(
+ inspect(connection).has_sequence(
+ "user_id_seq", schema=config.test_schema
+ ),
+ True,
+ )
+
+ def test_has_sequence_neg(self, connection):
+ eq_(
+ inspect(connection).has_sequence("some_sequence"),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_schemas_neg(self, connection):
+ eq_(
+ inspect(connection).has_sequence(
+ "some_sequence", schema=config.test_schema
+ ),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_default_not_in_remote(self, connection):
+ eq_(
+ inspect(connection).has_sequence(
+ "other_sequence", schema=config.test_schema
+ ),
+ False,
+ )
+
+ @testing.requires.schemas
+ def test_has_sequence_remote_not_in_default(self, connection):
+ eq_(
+ inspect(connection).has_sequence("schema_seq"),
+ False,
+ )
+
+ def test_get_sequence_names(self, connection):
+ exp = {"other_seq", "user_id_seq"}
+
+ res = set(inspect(connection).get_sequence_names())
+ is_true(res.intersection(exp) == exp)
+ is_true("schema_seq" not in res)
+
+ @testing.requires.schemas
+ def test_get_sequence_names_no_sequence_schema(self, connection):
+ eq_(
+ inspect(connection).get_sequence_names(
+ schema=config.test_schema_2
+ ),
+ [],
+ )
+
+ @testing.requires.schemas
+ def test_get_sequence_names_sequences_schema(self, connection):
+ eq_(
+ sorted(
+ inspect(connection).get_sequence_names(
+ schema=config.test_schema
+ )
+ ),
+ ["schema_seq", "user_id_seq"],
+ )
+
+
+class HasSequenceTestEmpty(fixtures.TestBase):
+ __requires__ = ("sequences",)
+ __backend__ = True
+
+ def test_get_sequence_names_no_sequence(self, connection):
+ eq_(
+ inspect(connection).get_sequence_names(),
+ [],
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
new file mode 100644
index 0000000..b96350e
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -0,0 +1,1508 @@
+# coding: utf-8
+
+import datetime
+import decimal
+import json
+import re
+
+from .. import config
+from .. import engines
+from .. import fixtures
+from .. import mock
+from ..assertions import eq_
+from ..assertions import is_
+from ..config import requirements
+from ..schema import Column
+from ..schema import Table
+from ... import and_
+from ... import BigInteger
+from ... import bindparam
+from ... import Boolean
+from ... import case
+from ... import cast
+from ... import Date
+from ... import DateTime
+from ... import Float
+from ... import Integer
+from ... import JSON
+from ... import literal
+from ... import MetaData
+from ... import null
+from ... import Numeric
+from ... import select
+from ... import String
+from ... import testing
+from ... import Text
+from ... import Time
+from ... import TIMESTAMP
+from ... import TypeDecorator
+from ... import Unicode
+from ... import UnicodeText
+from ... import util
+from ...orm import declarative_base
+from ...orm import Session
+from ...sql.sqltypes import LargeBinary
+from ...sql.sqltypes import PickleType
+from ...util import compat
+from ...util import u
+
+
+class _LiteralRoundTripFixture(object):
+ supports_whereclause = True
+
+ @testing.fixture
+ def literal_round_trip(self, metadata, connection):
+ """test literal rendering"""
+
+ # for literal, we test the literal render in an INSERT
+ # into a typed column. we can then SELECT it back as its
+ # official type; ideally we'd be able to use CAST here
+ # but MySQL in particular can't CAST fully
+
+ def run(type_, input_, output, filter_=None):
+ t = Table("t", metadata, Column("x", type_))
+ t.create(connection)
+
+ for value in input_:
+ ins = (
+ t.insert()
+ .values(x=literal(value, type_))
+ .compile(
+ dialect=testing.db.dialect,
+ compile_kwargs=dict(literal_binds=True),
+ )
+ )
+ connection.execute(ins)
+
+ if self.supports_whereclause:
+ stmt = t.select().where(t.c.x == literal(value))
+ else:
+ stmt = t.select()
+
+ stmt = stmt.compile(
+ dialect=testing.db.dialect,
+ compile_kwargs=dict(literal_binds=True),
+ )
+ for row in connection.execute(stmt):
+ value = row[0]
+ if filter_ is not None:
+ value = filter_(value)
+ assert value in output
+
+ return run
+
+
+class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase):
+ __requires__ = ("unicode_data",)
+
+ data = u(
+ "Alors vous imaginez ma 🐍 surprise, au lever du jour, "
+ "quand une drôle de petite 🐍 voix m’a réveillé. Elle "
+ "disait: « S’il vous plaît… dessine-moi 🐍 un mouton! »"
+ )
+
+ @property
+ def supports_whereclause(self):
+ return config.requirements.expressions_against_unbounded_text.enabled
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "unicode_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("unicode_data", cls.datatype),
+ )
+
+ def test_round_trip(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(), {"id": 1, "unicode_data": self.data}
+ )
+
+ row = connection.execute(select(unicode_table.c.unicode_data)).first()
+
+ eq_(row, (self.data,))
+ assert isinstance(row[0], util.text_type)
+
+ def test_round_trip_executemany(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(),
+ [{"id": i, "unicode_data": self.data} for i in range(1, 4)],
+ )
+
+ rows = connection.execute(
+ select(unicode_table.c.unicode_data)
+ ).fetchall()
+ eq_(rows, [(self.data,) for i in range(1, 4)])
+ for row in rows:
+ assert isinstance(row[0], util.text_type)
+
+ def _test_null_strings(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(), {"id": 1, "unicode_data": None}
+ )
+ row = connection.execute(select(unicode_table.c.unicode_data)).first()
+ eq_(row, (None,))
+
+ def _test_empty_strings(self, connection):
+ unicode_table = self.tables.unicode_table
+
+ connection.execute(
+ unicode_table.insert(), {"id": 1, "unicode_data": u("")}
+ )
+ row = connection.execute(select(unicode_table.c.unicode_data)).first()
+ eq_(row, (u(""),))
+
+ def test_literal(self, literal_round_trip):
+ literal_round_trip(self.datatype, [self.data], [self.data])
+
+ def test_literal_non_ascii(self, literal_round_trip):
+ literal_round_trip(
+ self.datatype, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+ )
+
+
+class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest):
+ __requires__ = ("unicode_data",)
+ __backend__ = True
+
+ datatype = Unicode(255)
+
+ @requirements.empty_strings_varchar
+ def test_empty_strings_varchar(self, connection):
+ self._test_empty_strings(connection)
+
+ def test_null_strings_varchar(self, connection):
+ self._test_null_strings(connection)
+
+
+class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest):
+ __requires__ = "unicode_data", "text_type"
+ __backend__ = True
+
+ datatype = UnicodeText()
+
+ @requirements.empty_strings_text
+ def test_empty_strings_text(self, connection):
+ self._test_empty_strings(connection)
+
+ def test_null_strings_text(self, connection):
+ self._test_null_strings(connection)
+
+
+class BinaryTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __requires__ = ("binary_literals",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "binary_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("binary_data", LargeBinary),
+ Column("pickle_data", PickleType),
+ )
+
+ def test_binary_roundtrip(self, connection):
+ binary_table = self.tables.binary_table
+
+ connection.execute(
+ binary_table.insert(), {"id": 1, "binary_data": b"this is binary"}
+ )
+ row = connection.execute(select(binary_table.c.binary_data)).first()
+ eq_(row, (b"this is binary",))
+
+ def test_pickle_roundtrip(self, connection):
+ binary_table = self.tables.binary_table
+
+ connection.execute(
+ binary_table.insert(),
+ {"id": 1, "pickle_data": {"foo": [1, 2, 3], "bar": "bat"}},
+ )
+ row = connection.execute(select(binary_table.c.pickle_data)).first()
+ eq_(row, ({"foo": [1, 2, 3], "bar": "bat"},))
+
+
+class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __requires__ = ("text_type",)
+ __backend__ = True
+
+ @property
+ def supports_whereclause(self):
+ return config.requirements.expressions_against_unbounded_text.enabled
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "text_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("text_data", Text),
+ )
+
+ def test_text_roundtrip(self, connection):
+ text_table = self.tables.text_table
+
+ connection.execute(
+ text_table.insert(), {"id": 1, "text_data": "some text"}
+ )
+ row = connection.execute(select(text_table.c.text_data)).first()
+ eq_(row, ("some text",))
+
+ @testing.requires.empty_strings_text
+ def test_text_empty_strings(self, connection):
+ text_table = self.tables.text_table
+
+ connection.execute(text_table.insert(), {"id": 1, "text_data": ""})
+ row = connection.execute(select(text_table.c.text_data)).first()
+ eq_(row, ("",))
+
+ def test_text_null_strings(self, connection):
+ text_table = self.tables.text_table
+
+ connection.execute(text_table.insert(), {"id": 1, "text_data": None})
+ row = connection.execute(select(text_table.c.text_data)).first()
+ eq_(row, (None,))
+
+ def test_literal(self, literal_round_trip):
+ literal_round_trip(Text, ["some text"], ["some text"])
+
+ def test_literal_non_ascii(self, literal_round_trip):
+ literal_round_trip(
+ Text, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+ )
+
+ def test_literal_quoting(self, literal_round_trip):
+ data = """some 'text' hey "hi there" that's text"""
+ literal_round_trip(Text, [data], [data])
+
+ def test_literal_backslashes(self, literal_round_trip):
+ data = r"backslash one \ backslash two \\ end"
+ literal_round_trip(Text, [data], [data])
+
+ def test_literal_percentsigns(self, literal_round_trip):
+ data = r"percent % signs %% percent"
+ literal_round_trip(Text, [data], [data])
+
+
+class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ @requirements.unbounded_varchar
+ def test_nolength_string(self):
+ metadata = MetaData()
+ foo = Table("foo", metadata, Column("one", String))
+
+ foo.create(config.db)
+ foo.drop(config.db)
+
+ def test_literal(self, literal_round_trip):
+ # note that in Python 3, this invokes the Unicode
+ # datatype for the literal part because all strings are unicode
+ literal_round_trip(String(40), ["some text"], ["some text"])
+
+ def test_literal_non_ascii(self, literal_round_trip):
+ literal_round_trip(
+ String(40), [util.u("réve🐍 illé")], [util.u("réve🐍 illé")]
+ )
+
+ def test_literal_quoting(self, literal_round_trip):
+ data = """some 'text' hey "hi there" that's text"""
+ literal_round_trip(String(40), [data], [data])
+
+ def test_literal_backslashes(self, literal_round_trip):
+ data = r"backslash one \ backslash two \\ end"
+ literal_round_trip(String(40), [data], [data])
+
+
+class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
+ compare = None
+
+ @classmethod
+ def define_tables(cls, metadata):
+ class Decorated(TypeDecorator):
+ impl = cls.datatype
+ cache_ok = True
+
+ Table(
+ "date_table",
+ metadata,
+ Column(
+ "id", Integer, primary_key=True, test_needs_autoincrement=True
+ ),
+ Column("date_data", cls.datatype),
+ Column("decorated_date_data", Decorated),
+ )
+
+ @testing.requires.datetime_implicit_bound
+ def test_select_direct(self, connection):
+ result = connection.scalar(select(literal(self.data)))
+ eq_(result, self.data)
+
+ def test_round_trip(self, connection):
+ date_table = self.tables.date_table
+
+ connection.execute(
+ date_table.insert(), {"id": 1, "date_data": self.data}
+ )
+
+ row = connection.execute(select(date_table.c.date_data)).first()
+
+ compare = self.compare or self.data
+ eq_(row, (compare,))
+ assert isinstance(row[0], type(compare))
+
+ def test_round_trip_decorated(self, connection):
+ date_table = self.tables.date_table
+
+ connection.execute(
+ date_table.insert(), {"id": 1, "decorated_date_data": self.data}
+ )
+
+ row = connection.execute(
+ select(date_table.c.decorated_date_data)
+ ).first()
+
+ compare = self.compare or self.data
+ eq_(row, (compare,))
+ assert isinstance(row[0], type(compare))
+
+ def test_null(self, connection):
+ date_table = self.tables.date_table
+
+ connection.execute(date_table.insert(), {"id": 1, "date_data": None})
+
+ row = connection.execute(select(date_table.c.date_data)).first()
+ eq_(row, (None,))
+
+ @testing.requires.datetime_literals
+ def test_literal(self, literal_round_trip):
+ compare = self.compare or self.data
+ literal_round_trip(self.datatype, [self.data], [compare])
+
+ @testing.requires.standalone_null_binds_whereclause
+ def test_null_bound_comparison(self):
+ # this test is based on an Oracle issue observed in #4886.
+ # passing NULL for an expression that needs to be interpreted as
+ # a certain type, does the DBAPI have the info it needs to do this.
+ date_table = self.tables.date_table
+ with config.db.begin() as conn:
+ result = conn.execute(
+ date_table.insert(), {"id": 1, "date_data": self.data}
+ )
+ id_ = result.inserted_primary_key[0]
+ stmt = select(date_table.c.id).where(
+ case(
+ (
+ bindparam("foo", type_=self.datatype) != None,
+ bindparam("foo", type_=self.datatype),
+ ),
+ else_=date_table.c.date_data,
+ )
+ == date_table.c.date_data
+ )
+
+ row = conn.execute(stmt, {"foo": None}).first()
+ eq_(row[0], id_)
+
+
+class DateTimeTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime",)
+ __backend__ = True
+ datatype = DateTime
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18)
+
+
+class DateTimeTZTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime_timezone",)
+ __backend__ = True
+ datatype = DateTime(timezone=True)
+ data = datetime.datetime(
+ 2012, 10, 15, 12, 57, 18, tzinfo=compat.timezone.utc
+ )
+
+
+class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime_microseconds",)
+ __backend__ = True
+ datatype = DateTime
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+
+
+class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("timestamp_microseconds",)
+ __backend__ = True
+ datatype = TIMESTAMP
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396)
+
+ @testing.requires.timestamp_microseconds_implicit_bound
+ def test_select_direct(self, connection):
+ result = connection.scalar(select(literal(self.data)))
+ eq_(result, self.data)
+
+
+class TimeTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("time",)
+ __backend__ = True
+ datatype = Time
+ data = datetime.time(12, 57, 18)
+
+
+class TimeTZTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("time_timezone",)
+ __backend__ = True
+ datatype = Time(timezone=True)
+ data = datetime.time(12, 57, 18, tzinfo=compat.timezone.utc)
+
+
+class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("time_microseconds",)
+ __backend__ = True
+ datatype = Time
+ data = datetime.time(12, 57, 18, 396)
+
+
+class DateTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("date",)
+ __backend__ = True
+ datatype = Date
+ data = datetime.date(2012, 10, 15)
+
+
+class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = "date", "date_coerces_from_datetime"
+ __backend__ = True
+ datatype = Date
+ data = datetime.datetime(2012, 10, 15, 12, 57, 18)
+ compare = datetime.date(2012, 10, 15)
+
+
+class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("datetime_historic",)
+ __backend__ = True
+ datatype = DateTime
+ data = datetime.datetime(1850, 11, 10, 11, 52, 35)
+
+
+class DateHistoricTest(_DateFixture, fixtures.TablesTest):
+ __requires__ = ("date_historic",)
+ __backend__ = True
+ datatype = Date
+ data = datetime.date(1727, 4, 1)
+
+
+class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ def test_literal(self, literal_round_trip):
+ literal_round_trip(Integer, [5], [5])
+
+ def test_huge_int(self, integer_round_trip):
+ integer_round_trip(BigInteger, 1376537018368127)
+
+ @testing.fixture
+ def integer_round_trip(self, metadata, connection):
+ def run(datatype, data):
+ int_table = Table(
+ "integer_table",
+ metadata,
+ Column(
+ "id",
+ Integer,
+ primary_key=True,
+ test_needs_autoincrement=True,
+ ),
+ Column("integer_data", datatype),
+ )
+
+ metadata.create_all(config.db)
+
+ connection.execute(
+ int_table.insert(), {"id": 1, "integer_data": data}
+ )
+
+ row = connection.execute(select(int_table.c.integer_data)).first()
+
+ eq_(row, (data,))
+
+ if util.py3k:
+ assert isinstance(row[0], int)
+ else:
+ assert isinstance(row[0], (long, int)) # noqa
+
+ return run
+
+
+class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ @testing.fixture
+ def string_as_int(self):
+ class StringAsInt(TypeDecorator):
+ impl = String(50)
+ cache_ok = True
+
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
+
+ def column_expression(self, col):
+ return cast(col, Integer)
+
+ def bind_expression(self, col):
+ return cast(col, String(50))
+
+ return StringAsInt()
+
+ def test_special_type(self, metadata, connection, string_as_int):
+
+ type_ = string_as_int
+
+ t = Table("t", metadata, Column("x", type_))
+ t.create(connection)
+
+ connection.execute(t.insert(), [{"x": x} for x in [1, 2, 3]])
+
+ result = {row[0] for row in connection.execute(t.select())}
+ eq_(result, {1, 2, 3})
+
+ result = {
+ row[0] for row in connection.execute(t.select().where(t.c.x == 2))
+ }
+ eq_(result, {2})
+
+
+class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
+ __backend__ = True
+
+ @testing.fixture
+ def do_numeric_test(self, metadata, connection):
+ @testing.emits_warning(
+ r".*does \*not\* support Decimal objects natively"
+ )
+ def run(type_, input_, output, filter_=None, check_scale=False):
+ t = Table("t", metadata, Column("x", type_))
+ t.create(connection)
+ connection.execute(t.insert(), [{"x": x} for x in input_])
+
+ result = {row[0] for row in connection.execute(t.select())}
+ output = set(output)
+ if filter_:
+ result = set(filter_(x) for x in result)
+ output = set(filter_(x) for x in output)
+ eq_(result, output)
+ if check_scale:
+ eq_([str(x) for x in result], [str(x) for x in output])
+
+ return run
+
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_render_literal_numeric(self, literal_round_trip):
+ literal_round_trip(
+ Numeric(precision=8, scale=4),
+ [15.7563, decimal.Decimal("15.7563")],
+ [decimal.Decimal("15.7563")],
+ )
+
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_render_literal_numeric_asfloat(self, literal_round_trip):
+ literal_round_trip(
+ Numeric(precision=8, scale=4, asdecimal=False),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ )
+
+ def test_render_literal_float(self, literal_round_trip):
+ literal_round_trip(
+ Float(4),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ filter_=lambda n: n is not None and round(n, 5) or None,
+ )
+
+ @testing.requires.precision_generic_float_type
+ def test_float_custom_scale(self, do_numeric_test):
+ do_numeric_test(
+ Float(None, decimal_return_scale=7, asdecimal=True),
+ [15.7563827, decimal.Decimal("15.7563827")],
+ [decimal.Decimal("15.7563827")],
+ check_scale=True,
+ )
+
+ def test_numeric_as_decimal(self, do_numeric_test):
+ do_numeric_test(
+ Numeric(precision=8, scale=4),
+ [15.7563, decimal.Decimal("15.7563")],
+ [decimal.Decimal("15.7563")],
+ )
+
+ def test_numeric_as_float(self, do_numeric_test):
+ do_numeric_test(
+ Numeric(precision=8, scale=4, asdecimal=False),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ )
+
+ @testing.requires.infinity_floats
+ def test_infinity_floats(self, do_numeric_test):
+ """test for #977, #7283"""
+
+ do_numeric_test(
+ Float(None),
+ [float("inf")],
+ [float("inf")],
+ )
+
+ @testing.requires.fetch_null_from_numeric
+ def test_numeric_null_as_decimal(self, do_numeric_test):
+ do_numeric_test(Numeric(precision=8, scale=4), [None], [None])
+
+ @testing.requires.fetch_null_from_numeric
+ def test_numeric_null_as_float(self, do_numeric_test):
+ do_numeric_test(
+ Numeric(precision=8, scale=4, asdecimal=False), [None], [None]
+ )
+
+ @testing.requires.floats_to_four_decimals
+ def test_float_as_decimal(self, do_numeric_test):
+ do_numeric_test(
+ Float(precision=8, asdecimal=True),
+ [15.7563, decimal.Decimal("15.7563"), None],
+ [decimal.Decimal("15.7563"), None],
+ filter_=lambda n: n is not None and round(n, 4) or None,
+ )
+
+ def test_float_as_float(self, do_numeric_test):
+ do_numeric_test(
+ Float(precision=8),
+ [15.7563, decimal.Decimal("15.7563")],
+ [15.7563],
+ filter_=lambda n: n is not None and round(n, 5) or None,
+ )
+
+ def test_float_coerce_round_trip(self, connection):
+ expr = 15.7563
+
+ val = connection.scalar(select(literal(expr)))
+ eq_(val, expr)
+
+ # this does not work in MySQL, see #4036, however we choose not
+ # to render CAST unconditionally since this is kind of an edge case.
+
+ @testing.requires.implicit_decimal_binds
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_decimal_coerce_round_trip(self, connection):
+ expr = decimal.Decimal("15.7563")
+
+ val = connection.scalar(select(literal(expr)))
+ eq_(val, expr)
+
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_decimal_coerce_round_trip_w_cast(self, connection):
+ expr = decimal.Decimal("15.7563")
+
+ val = connection.scalar(select(cast(expr, Numeric(10, 4))))
+ eq_(val, expr)
+
+ @testing.requires.precision_numerics_general
+ def test_precision_decimal(self, do_numeric_test):
+ numbers = set(
+ [
+ decimal.Decimal("54.234246451650"),
+ decimal.Decimal("0.004354"),
+ decimal.Decimal("900.0"),
+ ]
+ )
+
+ do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers)
+
+ @testing.requires.precision_numerics_enotation_large
+ def test_enotation_decimal(self, do_numeric_test):
+ """test exceedingly small decimals.
+
+ Decimal reports values with E notation when the exponent
+ is greater than 6.
+
+ """
+
+ numbers = set(
+ [
+ decimal.Decimal("1E-2"),
+ decimal.Decimal("1E-3"),
+ decimal.Decimal("1E-4"),
+ decimal.Decimal("1E-5"),
+ decimal.Decimal("1E-6"),
+ decimal.Decimal("1E-7"),
+ decimal.Decimal("1E-8"),
+ decimal.Decimal("0.01000005940696"),
+ decimal.Decimal("0.00000005940696"),
+ decimal.Decimal("0.00000000000696"),
+ decimal.Decimal("0.70000000000696"),
+ decimal.Decimal("696E-12"),
+ ]
+ )
+ do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers)
+
+ @testing.requires.precision_numerics_enotation_large
+ def test_enotation_decimal_large(self, do_numeric_test):
+ """test exceedingly large decimals."""
+
+ numbers = set(
+ [
+ decimal.Decimal("4E+8"),
+ decimal.Decimal("5748E+15"),
+ decimal.Decimal("1.521E+15"),
+ decimal.Decimal("00000000000000.1E+12"),
+ ]
+ )
+ do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers)
+
+ @testing.requires.precision_numerics_many_significant_digits
+ def test_many_significant_digits(self, do_numeric_test):
+ numbers = set(
+ [
+ decimal.Decimal("31943874831932418390.01"),
+ decimal.Decimal("319438950232418390.273596"),
+ decimal.Decimal("87673.594069654243"),
+ ]
+ )
+ do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers)
+
+ @testing.requires.precision_numerics_retains_significant_digits
+ def test_numeric_no_decimal(self, do_numeric_test):
+ numbers = set([decimal.Decimal("1.000")])
+ do_numeric_test(
+ Numeric(precision=5, scale=3), numbers, numbers, check_scale=True
+ )
+
+
+class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "boolean_table",
+ metadata,
+ Column("id", Integer, primary_key=True, autoincrement=False),
+ Column("value", Boolean),
+ Column("unconstrained_value", Boolean(create_constraint=False)),
+ )
+
+ def test_render_literal_bool(self, literal_round_trip):
+ literal_round_trip(Boolean(), [True, False], [True, False])
+
+ def test_round_trip(self, connection):
+ boolean_table = self.tables.boolean_table
+
+ connection.execute(
+ boolean_table.insert(),
+ {"id": 1, "value": True, "unconstrained_value": False},
+ )
+
+ row = connection.execute(
+ select(boolean_table.c.value, boolean_table.c.unconstrained_value)
+ ).first()
+
+ eq_(row, (True, False))
+ assert isinstance(row[0], bool)
+
+ @testing.requires.nullable_booleans
+ def test_null(self, connection):
+ boolean_table = self.tables.boolean_table
+
+ connection.execute(
+ boolean_table.insert(),
+ {"id": 1, "value": None, "unconstrained_value": None},
+ )
+
+ row = connection.execute(
+ select(boolean_table.c.value, boolean_table.c.unconstrained_value)
+ ).first()
+
+ eq_(row, (None, None))
+
+ def test_whereclause(self):
+ # testing "WHERE <column>" renders a compatible expression
+ boolean_table = self.tables.boolean_table
+
+ with config.db.begin() as conn:
+ conn.execute(
+ boolean_table.insert(),
+ [
+ {"id": 1, "value": True, "unconstrained_value": True},
+ {"id": 2, "value": False, "unconstrained_value": False},
+ ],
+ )
+
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(boolean_table.c.value)
+ ),
+ 1,
+ )
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(
+ boolean_table.c.unconstrained_value
+ )
+ ),
+ 1,
+ )
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(~boolean_table.c.value)
+ ),
+ 2,
+ )
+ eq_(
+ conn.scalar(
+ select(boolean_table.c.id).where(
+ ~boolean_table.c.unconstrained_value
+ )
+ ),
+ 2,
+ )
+
+
+class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest):
+ __requires__ = ("json_type",)
+ __backend__ = True
+
+ datatype = JSON
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "data_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(30), nullable=False),
+ Column("data", cls.datatype, nullable=False),
+ Column("nulldata", cls.datatype(none_as_null=True)),
+ )
+
+ def test_round_trip_data1(self, connection):
+ self._test_round_trip({"key1": "value1", "key2": "value2"}, connection)
+
+ def _test_round_trip(self, data_element, connection):
+ data_table = self.tables.data_table
+
+ connection.execute(
+ data_table.insert(),
+ {"id": 1, "name": "row1", "data": data_element},
+ )
+
+ row = connection.execute(select(data_table.c.data)).first()
+
+ eq_(row, (data_element,))
+
+ def _index_fixtures(include_comparison):
+
+ if include_comparison:
+ # basically SQL Server and MariaDB can kind of do json
+ # comparison, MySQL, PG and SQLite can't. not worth it.
+ json_elements = []
+ else:
+ json_elements = [
+ ("json", {"foo": "bar"}),
+ ("json", ["one", "two", "three"]),
+ (None, {"foo": "bar"}),
+ (None, ["one", "two", "three"]),
+ ]
+
+ elements = [
+ ("boolean", True),
+ ("boolean", False),
+ ("boolean", None),
+ ("string", "some string"),
+ ("string", None),
+ ("string", util.u("réve illé")),
+ (
+ "string",
+ util.u("réve🐍 illé"),
+ testing.requires.json_index_supplementary_unicode_element,
+ ),
+ ("integer", 15),
+ ("integer", 1),
+ ("integer", 0),
+ ("integer", None),
+ ("float", 28.5),
+ ("float", None),
+ (
+ "float",
+ 1234567.89,
+ ),
+ ("numeric", 1234567.89),
+ # this one "works" because the float value you see here is
+ # lost immediately to floating point stuff
+ ("numeric", 99998969694839.983485848, requirements.python3),
+ ("numeric", 99939.983485848, requirements.python3),
+ ("_decimal", decimal.Decimal("1234567.89")),
+ (
+ "_decimal",
+ decimal.Decimal("99998969694839.983485848"),
+ # fails on SQLite and MySQL (non-mariadb)
+ requirements.cast_precision_numerics_many_significant_digits,
+ ),
+ (
+ "_decimal",
+ decimal.Decimal("99939.983485848"),
+ ),
+ ] + json_elements
+
+ def decorate(fn):
+ fn = testing.combinations(id_="sa", *elements)(fn)
+
+ return fn
+
+ return decorate
+
+ def _json_value_insert(self, connection, datatype, value, data_element):
+ data_table = self.tables.data_table
+ if datatype == "_decimal":
+
+ # Python's builtin json serializer basically doesn't support
+ # Decimal objects without implicit float conversion period.
+ # users can otherwise use simplejson which supports
+ # precision decimals
+
+ # https://bugs.python.org/issue16535
+
+ # inserting as strings to avoid a new fixture around the
+ # dialect which would have idiosyncrasies for different
+ # backends.
+
+ class DecimalEncoder(json.JSONEncoder):
+ def default(self, o):
+ if isinstance(o, decimal.Decimal):
+ return str(o)
+ return super(DecimalEncoder, self).default(o)
+
+ json_data = json.dumps(data_element, cls=DecimalEncoder)
+
+ # take the quotes out. yup, there is *literally* no other
+ # way to get Python's json.dumps() to put all the digits in
+ # the string
+ json_data = re.sub(r'"(%s)"' % str(value), str(value), json_data)
+
+ datatype = "numeric"
+
+ connection.execute(
+ data_table.insert().values(
+ name="row1",
+ # to pass the string directly to every backend, including
+ # PostgreSQL which needs the value to be CAST as JSON
+ # both in the SQL as well as at the prepared statement
+ # level for asyncpg, while at the same time MySQL
+ # doesn't even support CAST for JSON, here we are
+ # sending the string embedded in the SQL without using
+ # a parameter.
+ data=bindparam(None, json_data, literal_execute=True),
+ nulldata=bindparam(None, json_data, literal_execute=True),
+ ),
+ )
+ else:
+ connection.execute(
+ data_table.insert(),
+ {
+ "name": "row1",
+ "data": data_element,
+ "nulldata": data_element,
+ },
+ )
+
+ p_s = None
+
+ if datatype:
+ if datatype == "numeric":
+ a, b = str(value).split(".")
+ s = len(b)
+ p = len(a) + s
+
+ if isinstance(value, decimal.Decimal):
+ compare_value = value
+ else:
+ compare_value = decimal.Decimal(str(value))
+
+ p_s = (p, s)
+ else:
+ compare_value = value
+ else:
+ compare_value = value
+
+ return datatype, compare_value, p_s
+
+ @_index_fixtures(False)
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_index_typed_access(self, datatype, value):
+ data_table = self.tables.data_table
+ data_element = {"key1": value}
+
+ with config.db.begin() as conn:
+
+ datatype, compare_value, p_s = self._json_value_insert(
+ conn, datatype, value, data_element
+ )
+
+ expr = data_table.c.data["key1"]
+ if datatype:
+ if datatype == "numeric" and p_s:
+ expr = expr.as_numeric(*p_s)
+ else:
+ expr = getattr(expr, "as_%s" % datatype)()
+
+ roundtrip = conn.scalar(select(expr))
+ eq_(roundtrip, compare_value)
+ if util.py3k: # skip py2k to avoid comparing unicode to str etc.
+ is_(type(roundtrip), type(compare_value))
+
+ @_index_fixtures(True)
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_index_typed_comparison(self, datatype, value):
+ data_table = self.tables.data_table
+ data_element = {"key1": value}
+
+ with config.db.begin() as conn:
+ datatype, compare_value, p_s = self._json_value_insert(
+ conn, datatype, value, data_element
+ )
+
+ expr = data_table.c.data["key1"]
+ if datatype:
+ if datatype == "numeric" and p_s:
+ expr = expr.as_numeric(*p_s)
+ else:
+ expr = getattr(expr, "as_%s" % datatype)()
+
+ row = conn.execute(
+ select(expr).where(expr == compare_value)
+ ).first()
+
+ # make sure we get a row even if value is None
+ eq_(row, (compare_value,))
+
+ @_index_fixtures(True)
+ @testing.emits_warning(r".*does \*not\* support Decimal objects natively")
+ def test_path_typed_comparison(self, datatype, value):
+ data_table = self.tables.data_table
+ data_element = {"key1": {"subkey1": value}}
+ with config.db.begin() as conn:
+
+ datatype, compare_value, p_s = self._json_value_insert(
+ conn, datatype, value, data_element
+ )
+
+ expr = data_table.c.data[("key1", "subkey1")]
+
+ if datatype:
+ if datatype == "numeric" and p_s:
+ expr = expr.as_numeric(*p_s)
+ else:
+ expr = getattr(expr, "as_%s" % datatype)()
+
+ row = conn.execute(
+ select(expr).where(expr == compare_value)
+ ).first()
+
+ # make sure we get a row even if value is None
+ eq_(row, (compare_value,))
+
+ @testing.combinations(
+ (True,),
+ (False,),
+ (None,),
+ (15,),
+ (0,),
+ (-1,),
+ (-1.0,),
+ (15.052,),
+ ("a string",),
+ (util.u("réve illé"),),
+ (util.u("réve🐍 illé"),),
+ )
+ def test_single_element_round_trip(self, element):
+ data_table = self.tables.data_table
+ data_element = element
+ with config.db.begin() as conn:
+ conn.execute(
+ data_table.insert(),
+ {
+ "name": "row1",
+ "data": data_element,
+ "nulldata": data_element,
+ },
+ )
+
+ row = conn.execute(
+ select(data_table.c.data, data_table.c.nulldata)
+ ).first()
+
+ eq_(row, (data_element, data_element))
+
+ def test_round_trip_custom_json(self):
+ data_table = self.tables.data_table
+ data_element = {"key1": "data1"}
+
+ js = mock.Mock(side_effect=json.dumps)
+ jd = mock.Mock(side_effect=json.loads)
+ engine = engines.testing_engine(
+ options=dict(json_serializer=js, json_deserializer=jd)
+ )
+
+ # support sqlite :memory: database...
+ data_table.create(engine, checkfirst=True)
+ with engine.begin() as conn:
+ conn.execute(
+ data_table.insert(), {"name": "row1", "data": data_element}
+ )
+ row = conn.execute(select(data_table.c.data)).first()
+
+ eq_(row, (data_element,))
+ eq_(js.mock_calls, [mock.call(data_element)])
+ eq_(jd.mock_calls, [mock.call(json.dumps(data_element))])
+
+ @testing.combinations(
+ ("parameters",),
+ ("multiparameters",),
+ ("values",),
+ ("omit",),
+ argnames="insert_type",
+ )
+ def test_round_trip_none_as_sql_null(self, connection, insert_type):
+ col = self.tables.data_table.c["nulldata"]
+
+ conn = connection
+
+ if insert_type == "parameters":
+ stmt, params = self.tables.data_table.insert(), {
+ "name": "r1",
+ "nulldata": None,
+ "data": None,
+ }
+ elif insert_type == "multiparameters":
+ stmt, params = self.tables.data_table.insert(), [
+ {"name": "r1", "nulldata": None, "data": None}
+ ]
+ elif insert_type == "values":
+ stmt, params = (
+ self.tables.data_table.insert().values(
+ name="r1",
+ nulldata=None,
+ data=None,
+ ),
+ {},
+ )
+ elif insert_type == "omit":
+ stmt, params = (
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": None},
+ )
+
+ else:
+ assert False
+
+ conn.execute(stmt, params)
+
+ eq_(
+ conn.scalar(
+ select(self.tables.data_table.c.name).where(col.is_(null()))
+ ),
+ "r1",
+ )
+
+ eq_(conn.scalar(select(col)), None)
+
+ def test_round_trip_json_null_as_json_null(self, connection):
+ col = self.tables.data_table.c["data"]
+
+ conn = connection
+ conn.execute(
+ self.tables.data_table.insert(),
+ {"name": "r1", "data": JSON.NULL},
+ )
+
+ eq_(
+ conn.scalar(
+ select(self.tables.data_table.c.name).where(
+ cast(col, String) == "null"
+ )
+ ),
+ "r1",
+ )
+
+ eq_(conn.scalar(select(col)), None)
+
+ @testing.combinations(
+ ("parameters",),
+ ("multiparameters",),
+ ("values",),
+ argnames="insert_type",
+ )
+ def test_round_trip_none_as_json_null(self, connection, insert_type):
+ col = self.tables.data_table.c["data"]
+
+ if insert_type == "parameters":
+ stmt, params = self.tables.data_table.insert(), {
+ "name": "r1",
+ "data": None,
+ }
+ elif insert_type == "multiparameters":
+ stmt, params = self.tables.data_table.insert(), [
+ {"name": "r1", "data": None}
+ ]
+ elif insert_type == "values":
+ stmt, params = (
+ self.tables.data_table.insert().values(name="r1", data=None),
+ {},
+ )
+ else:
+ assert False
+
+ conn = connection
+ conn.execute(stmt, params)
+
+ eq_(
+ conn.scalar(
+ select(self.tables.data_table.c.name).where(
+ cast(col, String) == "null"
+ )
+ ),
+ "r1",
+ )
+
+ eq_(conn.scalar(select(col)), None)
+
+ def test_unicode_round_trip(self):
+ # note we include Unicode supplementary characters as well
+ with config.db.begin() as conn:
+ conn.execute(
+ self.tables.data_table.insert(),
+ {
+ "name": "r1",
+ "data": {
+ util.u("réve🐍 illé"): util.u("réve🐍 illé"),
+ "data": {"k1": util.u("drôl🐍e")},
+ },
+ },
+ )
+
+ eq_(
+ conn.scalar(select(self.tables.data_table.c.data)),
+ {
+ util.u("réve🐍 illé"): util.u("réve🐍 illé"),
+ "data": {"k1": util.u("drôl🐍e")},
+ },
+ )
+
+ def test_eval_none_flag_orm(self, connection):
+
+ Base = declarative_base()
+
+ class Data(Base):
+ __table__ = self.tables.data_table
+
+ with Session(connection) as s:
+ d1 = Data(name="d1", data=None, nulldata=None)
+ s.add(d1)
+ s.commit()
+
+ s.bulk_insert_mappings(
+ Data, [{"name": "d2", "data": None, "nulldata": None}]
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d1")
+ .first(),
+ ("null", None),
+ )
+ eq_(
+ s.query(
+ cast(self.tables.data_table.c.data, String()),
+ cast(self.tables.data_table.c.nulldata, String),
+ )
+ .filter(self.tables.data_table.c.name == "d2")
+ .first(),
+ ("null", None),
+ )
+
+
+class JSONLegacyStringCastIndexTest(
+ _LiteralRoundTripFixture, fixtures.TablesTest
+):
+ """test JSON index access with "cast to string", which we have documented
+ for a long time as how to compare JSON values, but is ultimately not
+ reliable in all cases. The "as_XYZ()" comparators should be used
+ instead.
+
+ """
+
+ __requires__ = ("json_type", "legacy_unconditional_json_extract")
+ __backend__ = True
+
+ datatype = JSON
+
+ data1 = {"key1": "value1", "key2": "value2"}
+
+ data2 = {
+ "Key 'One'": "value1",
+ "key two": "value2",
+ "key three": "value ' three '",
+ }
+
+ data3 = {
+ "key1": [1, 2, 3],
+ "key2": ["one", "two", "three"],
+ "key3": [{"four": "five"}, {"six": "seven"}],
+ }
+
+ data4 = ["one", "two", "three"]
+
+ data5 = {
+ "nested": {
+ "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}],
+ "elem2": {"elem3": {"elem4": "elem5"}},
+ }
+ }
+
+ data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}}
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "data_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("name", String(30), nullable=False),
+ Column("data", cls.datatype),
+ Column("nulldata", cls.datatype(none_as_null=True)),
+ )
+
+ def _criteria_fixture(self):
+ with config.db.begin() as conn:
+ conn.execute(
+ self.tables.data_table.insert(),
+ [
+ {"name": "r1", "data": self.data1},
+ {"name": "r2", "data": self.data2},
+ {"name": "r3", "data": self.data3},
+ {"name": "r4", "data": self.data4},
+ {"name": "r5", "data": self.data5},
+ {"name": "r6", "data": self.data6},
+ ],
+ )
+
+ def _test_index_criteria(self, crit, expected, test_literal=True):
+ self._criteria_fixture()
+ with config.db.connect() as conn:
+ stmt = select(self.tables.data_table.c.name).where(crit)
+
+ eq_(conn.scalar(stmt), expected)
+
+ if test_literal:
+ literal_sql = str(
+ stmt.compile(
+ config.db, compile_kwargs={"literal_binds": True}
+ )
+ )
+
+ eq_(conn.exec_driver_sql(literal_sql).scalar(), expected)
+
+ def test_string_cast_crit_spaces_in_key(self):
+ name = self.tables.data_table.c.name
+ col = self.tables.data_table.c["data"]
+
+ # limit the rows here to avoid PG error
+ # "cannot extract field from a non-object", which is
+ # fixed in 9.4 but may exist in 9.3
+ self._test_index_criteria(
+ and_(
+ name.in_(["r1", "r2", "r3"]),
+ cast(col["key two"], String) == '"value2"',
+ ),
+ "r2",
+ )
+
+ @config.requirements.json_array_indexes
+ def test_string_cast_crit_simple_int(self):
+ name = self.tables.data_table.c.name
+ col = self.tables.data_table.c["data"]
+
+ # limit the rows here to avoid PG error
+ # "cannot extract array element from a non-array", which is
+ # fixed in 9.4 but may exist in 9.3
+ self._test_index_criteria(
+ and_(
+ name == "r4",
+ cast(col[1], String) == '"two"',
+ ),
+ "r4",
+ )
+
+ def test_string_cast_crit_mixed_path(self):
+ col = self.tables.data_table.c["data"]
+ self._test_index_criteria(
+ cast(col[("key3", 1, "six")], String) == '"seven"',
+ "r3",
+ )
+
+ def test_string_cast_crit_string_path(self):
+ col = self.tables.data_table.c["data"]
+ self._test_index_criteria(
+ cast(col[("nested", "elem2", "elem3", "elem4")], String)
+ == '"elem5"',
+ "r5",
+ )
+
+ def test_string_cast_crit_against_string_basic(self):
+ name = self.tables.data_table.c.name
+ col = self.tables.data_table.c["data"]
+
+ self._test_index_criteria(
+ and_(
+ name == "r6",
+ cast(col["b"], String) == '"some value"',
+ ),
+ "r6",
+ )
+
+
+__all__ = (
+ "BinaryTest",
+ "UnicodeVarcharTest",
+ "UnicodeTextTest",
+ "JSONTest",
+ "JSONLegacyStringCastIndexTest",
+ "DateTest",
+ "DateTimeTest",
+ "DateTimeTZTest",
+ "TextTest",
+ "NumericTest",
+ "IntegerTest",
+ "CastTypeDecoratorTest",
+ "DateTimeHistoricTest",
+ "DateTimeCoercedToDateTimeTest",
+ "TimeMicrosecondsTest",
+ "TimestampMicrosecondsTest",
+ "TimeTest",
+ "TimeTZTest",
+ "DateTimeMicrosecondsTest",
+ "DateHistoricTest",
+ "StringTest",
+ "BooleanTest",
+)
diff --git a/lib/sqlalchemy/testing/suite/test_unicode_ddl.py b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py
new file mode 100644
index 0000000..a4ae334
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py
@@ -0,0 +1,206 @@
+# coding: utf-8
+"""verrrrry basic unicode column name testing"""
+
+from sqlalchemy import desc
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import testing
+from sqlalchemy import util
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import Table
+from sqlalchemy.util import u
+from sqlalchemy.util import ue
+
+
+class UnicodeSchemaTest(fixtures.TablesTest):
+ __requires__ = ("unicode_ddl",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ global t1, t2, t3
+
+ t1 = Table(
+ u("unitable1"),
+ metadata,
+ Column(u("méil"), Integer, primary_key=True),
+ Column(ue("\u6e2c\u8a66"), Integer),
+ test_needs_fk=True,
+ )
+ t2 = Table(
+ u("Unitéble2"),
+ metadata,
+ Column(u("méil"), Integer, primary_key=True, key="a"),
+ Column(
+ ue("\u6e2c\u8a66"),
+ Integer,
+ ForeignKey(u("unitable1.méil")),
+ key="b",
+ ),
+ test_needs_fk=True,
+ )
+
+ # Few DBs support Unicode foreign keys
+ if testing.against("sqlite"):
+ t3 = Table(
+ ue("\u6e2c\u8a66"),
+ metadata,
+ Column(
+ ue("\u6e2c\u8a66_id"),
+ Integer,
+ primary_key=True,
+ autoincrement=False,
+ ),
+ Column(
+ ue("unitable1_\u6e2c\u8a66"),
+ Integer,
+ ForeignKey(ue("unitable1.\u6e2c\u8a66")),
+ ),
+ Column(
+ u("Unitéble2_b"), Integer, ForeignKey(u("Unitéble2.b"))
+ ),
+ Column(
+ ue("\u6e2c\u8a66_self"),
+ Integer,
+ ForeignKey(ue("\u6e2c\u8a66.\u6e2c\u8a66_id")),
+ ),
+ test_needs_fk=True,
+ )
+ else:
+ t3 = Table(
+ ue("\u6e2c\u8a66"),
+ metadata,
+ Column(
+ ue("\u6e2c\u8a66_id"),
+ Integer,
+ primary_key=True,
+ autoincrement=False,
+ ),
+ Column(ue("unitable1_\u6e2c\u8a66"), Integer),
+ Column(u("Unitéble2_b"), Integer),
+ Column(ue("\u6e2c\u8a66_self"), Integer),
+ test_needs_fk=True,
+ )
+
+ def test_insert(self, connection):
+ connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
+ connection.execute(t2.insert(), {u("a"): 1, u("b"): 1})
+ connection.execute(
+ t3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 1,
+ ue("unitable1_\u6e2c\u8a66"): 5,
+ u("Unitéble2_b"): 1,
+ ue("\u6e2c\u8a66_self"): 1,
+ },
+ )
+
+ eq_(connection.execute(t1.select()).fetchall(), [(1, 5)])
+ eq_(connection.execute(t2.select()).fetchall(), [(1, 1)])
+ eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)])
+
+ def test_col_targeting(self, connection):
+ connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
+ connection.execute(t2.insert(), {u("a"): 1, u("b"): 1})
+ connection.execute(
+ t3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 1,
+ ue("unitable1_\u6e2c\u8a66"): 5,
+ u("Unitéble2_b"): 1,
+ ue("\u6e2c\u8a66_self"): 1,
+ },
+ )
+
+ row = connection.execute(t1.select()).first()
+ eq_(row._mapping[t1.c[u("méil")]], 1)
+ eq_(row._mapping[t1.c[ue("\u6e2c\u8a66")]], 5)
+
+ row = connection.execute(t2.select()).first()
+ eq_(row._mapping[t2.c[u("a")]], 1)
+ eq_(row._mapping[t2.c[u("b")]], 1)
+
+ row = connection.execute(t3.select()).first()
+ eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_id")]], 1)
+ eq_(row._mapping[t3.c[ue("unitable1_\u6e2c\u8a66")]], 5)
+ eq_(row._mapping[t3.c[u("Unitéble2_b")]], 1)
+ eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_self")]], 1)
+
+ def test_reflect(self, connection):
+ connection.execute(t1.insert(), {u("méil"): 2, ue("\u6e2c\u8a66"): 7})
+ connection.execute(t2.insert(), {u("a"): 2, u("b"): 2})
+ connection.execute(
+ t3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 2,
+ ue("unitable1_\u6e2c\u8a66"): 7,
+ u("Unitéble2_b"): 2,
+ ue("\u6e2c\u8a66_self"): 2,
+ },
+ )
+
+ meta = MetaData()
+ tt1 = Table(t1.name, meta, autoload_with=connection)
+ tt2 = Table(t2.name, meta, autoload_with=connection)
+ tt3 = Table(t3.name, meta, autoload_with=connection)
+
+ connection.execute(tt1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5})
+ connection.execute(tt2.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 1})
+ connection.execute(
+ tt3.insert(),
+ {
+ ue("\u6e2c\u8a66_id"): 1,
+ ue("unitable1_\u6e2c\u8a66"): 5,
+ u("Unitéble2_b"): 1,
+ ue("\u6e2c\u8a66_self"): 1,
+ },
+ )
+
+ eq_(
+ connection.execute(
+ tt1.select().order_by(desc(u("méil")))
+ ).fetchall(),
+ [(2, 7), (1, 5)],
+ )
+ eq_(
+ connection.execute(
+ tt2.select().order_by(desc(u("méil")))
+ ).fetchall(),
+ [(2, 2), (1, 1)],
+ )
+ eq_(
+ connection.execute(
+ tt3.select().order_by(desc(ue("\u6e2c\u8a66_id")))
+ ).fetchall(),
+ [(2, 7, 2, 2), (1, 5, 1, 1)],
+ )
+
+ def test_repr(self):
+ meta = MetaData()
+ t = Table(
+ ue("\u6e2c\u8a66"), meta, Column(ue("\u6e2c\u8a66_id"), Integer)
+ )
+
+ if util.py2k:
+ eq_(
+ repr(t),
+ (
+ "Table('\\u6e2c\\u8a66', MetaData(), "
+ "Column('\\u6e2c\\u8a66_id', Integer(), "
+ "table=<\u6e2c\u8a66>), "
+ "schema=None)"
+ ),
+ )
+ else:
+ eq_(
+ repr(t),
+ (
+ "Table('測試', MetaData(), "
+ "Column('測試_id', Integer(), "
+ "table=<測試>), "
+ "schema=None)"
+ ),
+ )
diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py
new file mode 100644
index 0000000..f04a9d5
--- /dev/null
+++ b/lib/sqlalchemy/testing/suite/test_update_delete.py
@@ -0,0 +1,60 @@
+from .. import fixtures
+from ..assertions import eq_
+from ..schema import Column
+from ..schema import Table
+from ... import Integer
+from ... import String
+
+
+class SimpleUpdateDeleteTest(fixtures.TablesTest):
+ run_deletes = "each"
+ __requires__ = ("sane_rowcount",)
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "plain_pk",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.plain_pk.insert(),
+ [
+ {"id": 1, "data": "d1"},
+ {"id": 2, "data": "d2"},
+ {"id": 3, "data": "d3"},
+ ],
+ )
+
+ def test_update(self, connection):
+ t = self.tables.plain_pk
+ r = connection.execute(
+ t.update().where(t.c.id == 2), dict(data="d2_new")
+ )
+ assert not r.is_insert
+ assert not r.returns_rows
+ assert r.rowcount == 1
+
+ eq_(
+ connection.execute(t.select().order_by(t.c.id)).fetchall(),
+ [(1, "d1"), (2, "d2_new"), (3, "d3")],
+ )
+
+ def test_delete(self, connection):
+ t = self.tables.plain_pk
+ r = connection.execute(t.delete().where(t.c.id == 2))
+ assert not r.is_insert
+ assert not r.returns_rows
+ assert r.rowcount == 1
+ eq_(
+ connection.execute(t.select().order_by(t.c.id)).fetchall(),
+ [(1, "d1"), (3, "d3")],
+ )
+
+
+__all__ = ("SimpleUpdateDeleteTest",)