diff options
author | xiubuzhe <xiubuzhe@sina.com> | 2023-10-08 20:59:00 +0800 |
---|---|---|
committer | xiubuzhe <xiubuzhe@sina.com> | 2023-10-08 20:59:00 +0800 |
commit | 1dac2263372df2b85db5d029a45721fa158a5c9d (patch) | |
tree | 0365f9c57df04178a726d7584ca6a6b955a7ce6a /lib/sqlalchemy/testing/suite | |
parent | b494be364bb39e1de128ada7dc576a729d99907e (diff) | |
download | sunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.tar.gz sunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.tar.bz2 sunhpc-1dac2263372df2b85db5d029a45721fa158a5c9d.zip |
first add files
Diffstat (limited to 'lib/sqlalchemy/testing/suite')
-rw-r--r-- | lib/sqlalchemy/testing/suite/__init__.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_cte.py | 204 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_ddl.py | 381 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_deprecations.py | 145 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_dialect.py | 361 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_insert.py | 367 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_reflection.py | 1738 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_results.py | 426 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_rowcount.py | 165 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_select.py | 1783 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_sequence.py | 282 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 1508 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_unicode_ddl.py | 206 | ||||
-rw-r--r-- | lib/sqlalchemy/testing/suite/test_update_delete.py | 60 |
14 files changed, 7639 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/suite/__init__.py b/lib/sqlalchemy/testing/suite/__init__.py new file mode 100644 index 0000000..30817e1 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/__init__.py @@ -0,0 +1,13 @@ +from .test_cte import * # noqa +from .test_ddl import * # noqa +from .test_deprecations import * # noqa +from .test_dialect import * # noqa +from .test_insert import * # noqa +from .test_reflection import * # noqa +from .test_results import * # noqa +from .test_rowcount import * # noqa +from .test_select import * # noqa +from .test_sequence import * # noqa +from .test_types import * # noqa +from .test_unicode_ddl import * # noqa +from .test_update_delete import * # noqa diff --git a/lib/sqlalchemy/testing/suite/test_cte.py b/lib/sqlalchemy/testing/suite/test_cte.py new file mode 100644 index 0000000..a94ee55 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_cte.py @@ -0,0 +1,204 @@ +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import ForeignKey +from ... import Integer +from ... import select +from ... import String +from ... import testing + + +class CTETest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("ctes",) + + run_inserts = "each" + run_deletes = "each" + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", ForeignKey("some_table.id")), + ) + + Table( + "some_other_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + Column("parent_id", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "d1", "parent_id": None}, + {"id": 2, "data": "d2", "parent_id": 1}, + {"id": 3, "data": "d3", "parent_id": 1}, + {"id": 4, "data": "d4", "parent_id": 3}, + {"id": 5, "data": "d5", "parent_id": 3}, + ], + ) + + def test_select_nonrecursive_round_trip(self, connection): + some_table = self.tables.some_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + result = connection.execute( + select(cte.c.data).where(cte.c.data.in_(["d4", "d5"])) + ) + eq_(result.fetchall(), [("d4",)]) + + def test_select_recursive_round_trip(self, connection): + some_table = self.tables.some_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte", recursive=True) + ) + + cte_alias = cte.alias("c1") + st1 = some_table.alias() + # note that SQL Server requires this to be UNION ALL, + # can't be UNION + cte = cte.union_all( + select(st1).where(st1.c.id == cte_alias.c.parent_id) + ) + result = connection.execute( + select(cte.c.data) + .where(cte.c.data != "d2") + .order_by(cte.c.data.desc()) + ) + eq_( + result.fetchall(), + [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)], + ) + + def test_insert_from_select_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(cte) + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)], + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.update_from + def test_update_from_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.update() + .values(parent_id=5) + .where(some_other_table.c.data == cte.c.data) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [ + (1, "d1", None), + (2, "d2", 5), + (3, "d3", 5), + (4, "d4", 5), + (5, "d5", 3), + ], + ) + + @testing.requires.ctes_with_update_delete + @testing.requires.delete_from + def test_delete_from_round_trip(self, connection): + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.delete().where( + some_other_table.c.data == cte.c.data + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(1, "d1", None), (5, "d5", 3)], + ) + + @testing.requires.ctes_with_update_delete + def test_delete_scalar_subq_round_trip(self, connection): + + some_table = self.tables.some_table + some_other_table = self.tables.some_other_table + + connection.execute( + some_other_table.insert().from_select( + ["id", "data", "parent_id"], select(some_table) + ) + ) + + cte = ( + select(some_table) + .where(some_table.c.data.in_(["d2", "d3", "d4"])) + .cte("some_cte") + ) + connection.execute( + some_other_table.delete().where( + some_other_table.c.data + == select(cte.c.data) + .where(cte.c.id == some_other_table.c.id) + .scalar_subquery() + ) + ) + eq_( + connection.execute( + select(some_other_table).order_by(some_other_table.c.id) + ).fetchall(), + [(1, "d1", None), (5, "d5", 3)], + ) diff --git a/lib/sqlalchemy/testing/suite/test_ddl.py b/lib/sqlalchemy/testing/suite/test_ddl.py new file mode 100644 index 0000000..b3fee55 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_ddl.py @@ -0,0 +1,381 @@ +import random + +from . import testing +from .. import config +from .. import fixtures +from .. import util +from ..assertions import eq_ +from ..assertions import is_false +from ..assertions import is_true +from ..config import requirements +from ..schema import Table +from ... import CheckConstraint +from ... import Column +from ... import ForeignKeyConstraint +from ... import Index +from ... import inspect +from ... import Integer +from ... import schema +from ... import String +from ... import UniqueConstraint + + +class TableDDLTest(fixtures.TestBase): + __backend__ = True + + def _simple_fixture(self, schema=None): + return Table( + "test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + schema=schema, + ) + + def _underscore_fixture(self): + return Table( + "_test_table", + self.metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("_data", String(50)), + ) + + def _table_index_fixture(self, schema=None): + table = self._simple_fixture(schema=schema) + idx = Index("test_index", table.c.data) + return table, idx + + def _simple_roundtrip(self, table): + with config.db.begin() as conn: + conn.execute(table.insert().values((1, "some data"))) + result = conn.execute(table.select()) + eq_(result.first(), (1, "some data")) + + @requirements.create_table + @util.provide_metadata + def test_create_table(self): + table = self._simple_fixture() + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.create_table + @requirements.schemas + @util.provide_metadata + def test_create_table_schema(self): + table = self._simple_fixture(schema=config.test_schema) + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.drop_table + @util.provide_metadata + def test_drop_table(self): + table = self._simple_fixture() + table.create(config.db, checkfirst=False) + table.drop(config.db, checkfirst=False) + + @requirements.create_table + @util.provide_metadata + def test_underscore_names(self): + table = self._underscore_fixture() + table.create(config.db, checkfirst=False) + self._simple_roundtrip(table) + + @requirements.comment_reflection + @util.provide_metadata + def test_add_table_comment(self, connection): + table = self._simple_fixture() + table.create(connection, checkfirst=False) + table.comment = "a comment" + connection.execute(schema.SetTableComment(table)) + eq_( + inspect(connection).get_table_comment("test_table"), + {"text": "a comment"}, + ) + + @requirements.comment_reflection + @util.provide_metadata + def test_drop_table_comment(self, connection): + table = self._simple_fixture() + table.create(connection, checkfirst=False) + table.comment = "a comment" + connection.execute(schema.SetTableComment(table)) + connection.execute(schema.DropTableComment(table)) + eq_( + inspect(connection).get_table_comment("test_table"), {"text": None} + ) + + @requirements.table_ddl_if_exists + @util.provide_metadata + def test_create_table_if_not_exists(self, connection): + table = self._simple_fixture() + + connection.execute(schema.CreateTable(table, if_not_exists=True)) + + is_true(inspect(connection).has_table("test_table")) + connection.execute(schema.CreateTable(table, if_not_exists=True)) + + @requirements.index_ddl_if_exists + @util.provide_metadata + def test_create_index_if_not_exists(self, connection): + table, idx = self._table_index_fixture() + + connection.execute(schema.CreateTable(table, if_not_exists=True)) + is_true(inspect(connection).has_table("test_table")) + is_false( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.CreateIndex(idx, if_not_exists=True)) + + is_true( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.CreateIndex(idx, if_not_exists=True)) + + @requirements.table_ddl_if_exists + @util.provide_metadata + def test_drop_table_if_exists(self, connection): + table = self._simple_fixture() + + table.create(connection) + + is_true(inspect(connection).has_table("test_table")) + + connection.execute(schema.DropTable(table, if_exists=True)) + + is_false(inspect(connection).has_table("test_table")) + + connection.execute(schema.DropTable(table, if_exists=True)) + + @requirements.index_ddl_if_exists + @util.provide_metadata + def test_drop_index_if_exists(self, connection): + table, idx = self._table_index_fixture() + + table.create(connection) + + is_true( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.DropIndex(idx, if_exists=True)) + + is_false( + "test_index" + in [ + ix["name"] + for ix in inspect(connection).get_indexes("test_table") + ] + ) + + connection.execute(schema.DropIndex(idx, if_exists=True)) + + +class FutureTableDDLTest(fixtures.FutureEngineMixin, TableDDLTest): + pass + + +class LongNameBlowoutTest(fixtures.TestBase): + """test the creation of a variety of DDL structures and ensure + label length limits pass on backends + + """ + + __backend__ = True + + def fk(self, metadata, connection): + convention = { + "fk": "foreign_key_%(table_name)s_" + "%(column_0_N_name)s_" + "%(referred_table_name)s_" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(20)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + test_needs_fk=True, + ) + + cons = ForeignKeyConstraint( + ["aid"], ["a_things_with_stuff.id_long_column_name"] + ) + Table( + "b_related_things_of_value", + metadata, + Column( + "aid", + ), + cons, + test_needs_fk=True, + ) + actual_name = cons.name + + metadata.create_all(connection) + + if testing.requires.foreign_key_constraint_name_reflection.enabled: + insp = inspect(connection) + fks = insp.get_foreign_keys("b_related_things_of_value") + reflected_name = fks[0]["name"] + + return actual_name, reflected_name + else: + return actual_name, None + + def pk(self, metadata, connection): + convention = { + "pk": "primary_key_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + a = Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer, primary_key=True), + ) + cons = a.primary_key + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + pk = insp.get_pk_constraint("a_things_with_stuff") + reflected_name = pk["name"] + return actual_name, reflected_name + + def ix(self, metadata, connection): + convention = { + "ix": "index_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + a = Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer), + ) + cons = Index(None, a.c.id_long_column_name, a.c.id_another_long_name) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + ix = insp.get_indexes("a_things_with_stuff") + reflected_name = ix[0]["name"] + return actual_name, reflected_name + + def uq(self, metadata, connection): + convention = { + "uq": "unique_constraint_%(table_name)s_" + "%(column_0_N_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + cons = UniqueConstraint("id_long_column_name", "id_another_long_name") + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("id_another_long_name", Integer), + cons, + ) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + uq = insp.get_unique_constraints("a_things_with_stuff") + reflected_name = uq[0]["name"] + return actual_name, reflected_name + + def ck(self, metadata, connection): + convention = { + "ck": "check_constraint_%(table_name)s" + + ( + "_".join( + "".join(random.choice("abcdef") for j in range(30)) + for i in range(10) + ) + ), + } + metadata.naming_convention = convention + + cons = CheckConstraint("some_long_column_name > 5") + Table( + "a_things_with_stuff", + metadata, + Column("id_long_column_name", Integer, primary_key=True), + Column("some_long_column_name", Integer), + cons, + ) + actual_name = cons.name + + metadata.create_all(connection) + insp = inspect(connection) + ck = insp.get_check_constraints("a_things_with_stuff") + reflected_name = ck[0]["name"] + return actual_name, reflected_name + + @testing.combinations( + ("fk",), + ("pk",), + ("ix",), + ("ck", testing.requires.check_constraint_reflection.as_skips()), + ("uq", testing.requires.unique_constraint_reflection.as_skips()), + argnames="type_", + ) + def test_long_convention_name(self, type_, metadata, connection): + actual_name, reflected_name = getattr(self, type_)( + metadata, connection + ) + + assert len(actual_name) > 255 + + if reflected_name is not None: + overlap = actual_name[0 : len(reflected_name)] + if len(overlap) < len(actual_name): + eq_(overlap[0:-5], reflected_name[0 : len(overlap) - 5]) + else: + eq_(overlap, reflected_name) + + +__all__ = ("TableDDLTest", "FutureTableDDLTest", "LongNameBlowoutTest") diff --git a/lib/sqlalchemy/testing/suite/test_deprecations.py b/lib/sqlalchemy/testing/suite/test_deprecations.py new file mode 100644 index 0000000..b36162f --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_deprecations.py @@ -0,0 +1,145 @@ +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import Integer +from ... import select +from ... import testing +from ... import union + + +class DeprecatedCompoundSelectTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ], + ) + + def _assert_result(self, conn, select, result, params=()): + eq_(conn.execute(select, params).fetchall(), result) + + def test_plain_union(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + # note we've had to remove one use case entirely, which is this + # one. the Select gets its FROMS from the WHERE clause and the + # columns clause, but not the ORDER BY, which means the old ".c" system + # allowed you to "order_by(s.c.foo)" to get an unnamed column in the + # ORDER BY without adding the SELECT into the FROM and breaking the + # query. Users will have to adjust for this use case if they were doing + # it before. + def _dont_test_select_from_plain_union(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2).alias().select() + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.order_by_col_from_union + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_wo_limit_offset + def test_order_by_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_distinct_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).distinct() + s2 = select(table).where(table.c.id == 3).distinct() + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_limit_offset_aliased_selectable_in_unions(self, connection): + table = self.tables.some_table + s1 = ( + select(table) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select(table) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + + u1 = union(s1, s2).limit(2) + with testing.expect_deprecated( + "The SelectBase.c and SelectBase.columns " + "attributes are deprecated" + ): + self._assert_result( + connection, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) diff --git a/lib/sqlalchemy/testing/suite/test_dialect.py b/lib/sqlalchemy/testing/suite/test_dialect.py new file mode 100644 index 0000000..c2c17d0 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_dialect.py @@ -0,0 +1,361 @@ +#! coding: utf-8 + +from . import testing +from .. import assert_raises +from .. import config +from .. import engines +from .. import eq_ +from .. import fixtures +from .. import ne_ +from .. import provide_metadata +from ..config import requirements +from ..provision import set_default_schema_on_connection +from ..schema import Column +from ..schema import Table +from ... import bindparam +from ... import event +from ... import exc +from ... import Integer +from ... import literal_column +from ... import select +from ... import String +from ...util import compat + + +class ExceptionTest(fixtures.TablesTest): + """Test basic exception wrapping. + + DBAPIs vary a lot in exception behavior so to actually anticipate + specific exceptions from real round trips, we need to be conservative. + + """ + + run_deletes = "each" + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + + @requirements.duplicate_key_raises_integrity_error + def test_integrity_error(self): + + with config.db.connect() as conn: + + trans = conn.begin() + conn.execute( + self.tables.manual_pk.insert(), {"id": 1, "data": "d1"} + ) + + assert_raises( + exc.IntegrityError, + conn.execute, + self.tables.manual_pk.insert(), + {"id": 1, "data": "d1"}, + ) + + trans.rollback() + + def test_exception_with_non_ascii(self): + with config.db.connect() as conn: + try: + # try to create an error message that likely has non-ascii + # characters in the DBAPI's message string. unfortunately + # there's no way to make this happen with some drivers like + # mysqlclient, pymysql. this at least does produce a non- + # ascii error message for cx_oracle, psycopg2 + conn.execute(select(literal_column(u"méil"))) + assert False + except exc.DBAPIError as err: + err_str = str(err) + + assert str(err.orig) in str(err) + + # test that we are actually getting string on Py2k, unicode + # on Py3k. + if compat.py2k: + assert isinstance(err_str, str) + else: + assert isinstance(err_str, str) + + +class IsolationLevelTest(fixtures.TestBase): + __backend__ = True + + __requires__ = ("isolation_level",) + + def _get_non_default_isolation_level(self): + levels = requirements.get_isolation_levels(config) + + default = levels["default"] + supported = levels["supported"] + + s = set(supported).difference(["AUTOCOMMIT", default]) + if s: + return s.pop() + else: + config.skip_test("no non-default isolation level available") + + def test_default_isolation_level(self): + eq_( + config.db.dialect.default_isolation_level, + requirements.get_isolation_levels(config)["default"], + ) + + def test_non_default_isolation_level(self): + non_default = self._get_non_default_isolation_level() + + with config.db.connect() as conn: + existing = conn.get_isolation_level() + + ne_(existing, non_default) + + conn.execution_options(isolation_level=non_default) + + eq_(conn.get_isolation_level(), non_default) + + conn.dialect.reset_isolation_level(conn.connection) + + eq_(conn.get_isolation_level(), existing) + + def test_all_levels(self): + levels = requirements.get_isolation_levels(config) + + all_levels = levels["supported"] + + for level in set(all_levels).difference(["AUTOCOMMIT"]): + with config.db.connect() as conn: + conn.execution_options(isolation_level=level) + + eq_(conn.get_isolation_level(), level) + + trans = conn.begin() + trans.rollback() + + eq_(conn.get_isolation_level(), level) + + with config.db.connect() as conn: + eq_( + conn.get_isolation_level(), + levels["default"], + ) + + +class AutocommitIsolationTest(fixtures.TablesTest): + + run_deletes = "each" + + __requires__ = ("autocommit",) + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + test_needs_acid=True, + ) + + def _test_conn_autocommits(self, conn, autocommit): + trans = conn.begin() + conn.execute( + self.tables.some_table.insert(), {"id": 1, "data": "some data"} + ) + trans.rollback() + + eq_( + conn.scalar(select(self.tables.some_table.c.id)), + 1 if autocommit else None, + ) + + with conn.begin(): + conn.execute(self.tables.some_table.delete()) + + def test_autocommit_on(self, connection_no_trans): + conn = connection_no_trans + c2 = conn.execution_options(isolation_level="AUTOCOMMIT") + self._test_conn_autocommits(c2, True) + + c2.dialect.reset_isolation_level(c2.connection) + + self._test_conn_autocommits(conn, False) + + def test_autocommit_off(self, connection_no_trans): + conn = connection_no_trans + self._test_conn_autocommits(conn, False) + + def test_turn_autocommit_off_via_default_iso_level( + self, connection_no_trans + ): + conn = connection_no_trans + conn = conn.execution_options(isolation_level="AUTOCOMMIT") + self._test_conn_autocommits(conn, True) + + conn.execution_options( + isolation_level=requirements.get_isolation_levels(config)[ + "default" + ] + ) + self._test_conn_autocommits(conn, False) + + +class EscapingTest(fixtures.TestBase): + @provide_metadata + def test_percent_sign_round_trip(self): + """test that the DBAPI accommodates for escaped / nonescaped + percent signs in a way that matches the compiler + + """ + m = self.metadata + t = Table("t", m, Column("data", String(50))) + t.create(config.db) + with config.db.begin() as conn: + conn.execute(t.insert(), dict(data="some % value")) + conn.execute(t.insert(), dict(data="some %% other value")) + + eq_( + conn.scalar( + select(t.c.data).where( + t.c.data == literal_column("'some % value'") + ) + ), + "some % value", + ) + + eq_( + conn.scalar( + select(t.c.data).where( + t.c.data == literal_column("'some %% other value'") + ) + ), + "some %% other value", + ) + + +class WeCanSetDefaultSchemaWEventsTest(fixtures.TestBase): + __backend__ = True + + __requires__ = ("default_schema_name_switch",) + + def test_control_case(self): + default_schema_name = config.db.dialect.default_schema_name + + eng = engines.testing_engine() + with eng.connect(): + pass + + eq_(eng.dialect.default_schema_name, default_schema_name) + + def test_wont_work_wo_insert(self): + default_schema_name = config.db.dialect.default_schema_name + + eng = engines.testing_engine() + + @event.listens_for(eng, "connect") + def on_connect(dbapi_connection, connection_record): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, default_schema_name) + + def test_schema_change_on_connect(self): + eng = engines.testing_engine() + + @event.listens_for(eng, "connect", insert=True) + def on_connect(dbapi_connection, connection_record): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, config.test_schema) + + def test_schema_change_works_w_transactions(self): + eng = engines.testing_engine() + + @event.listens_for(eng, "connect", insert=True) + def on_connect(dbapi_connection, *arg): + set_default_schema_on_connection( + config, dbapi_connection, config.test_schema + ) + + with eng.connect() as conn: + trans = conn.begin() + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + trans.rollback() + + what_it_should_be = eng.dialect._get_default_schema_name(conn) + eq_(what_it_should_be, config.test_schema) + + eq_(eng.dialect.default_schema_name, config.test_schema) + + +class FutureWeCanSetDefaultSchemaWEventsTest( + fixtures.FutureEngineMixin, WeCanSetDefaultSchemaWEventsTest +): + pass + + +class DifficultParametersTest(fixtures.TestBase): + __backend__ = True + + @testing.combinations( + ("boring",), + ("per cent",), + ("per % cent",), + ("%percent",), + ("par(ens)",), + ("percent%(ens)yah",), + ("col:ons",), + ("more :: %colons%",), + ("/slashes/",), + ("more/slashes",), + ("q?marks",), + ("1param",), + ("1col:on",), + argnames="name", + ) + def test_round_trip(self, name, connection, metadata): + t = Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column(name, String(50), nullable=False), + ) + + # table is created + t.create(connection) + + # automatic param generated by insert + connection.execute(t.insert().values({"id": 1, name: "some name"})) + + # automatic param generated by criteria, plus selecting the column + stmt = select(t.c[name]).where(t.c[name] == "some name") + + eq_(connection.scalar(stmt), "some name") + + # use the name in a param explicitly + stmt = select(t.c[name]).where(t.c[name] == bindparam(name)) + + row = connection.execute(stmt, {name: "some name"}).first() + + # name works as the key from cursor.description + eq_(row._mapping[name], "some name") diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py new file mode 100644 index 0000000..3c22f50 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -0,0 +1,367 @@ +from .. import config +from .. import engines +from .. import fixtures +from ..assertions import eq_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import Integer +from ... import literal +from ... import literal_column +from ... import select +from ... import String + + +class LastrowidTest(fixtures.TablesTest): + run_deletes = "each" + + __backend__ = True + + __requires__ = "implements_get_lastrowid", "autoincrement_insert" + + __engine_options__ = {"implicit_returning": False} + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_( + row, + ( + conn.dialect.default_sequence_base, + "some data", + ), + ) + + def test_autoincrement_on_insert(self, connection): + + connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.autoinc_pk, connection) + + def test_last_inserted_id(self, connection): + + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(r.inserted_primary_key, (pk,)) + + @requirements.dbapi_lastrowid + def test_native_lastrowid_autoinc(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + lastrowid = r.lastrowid + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(lastrowid, pk) + + +class InsertBehaviorTest(fixtures.TablesTest): + run_deletes = "each" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + Table( + "manual_pk", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("data", String(50)), + ) + Table( + "includes_defaults", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + Column("x", Integer, default=5), + Column( + "y", + Integer, + default=literal_column("2", type_=Integer) + literal(2), + ), + ) + + @requirements.autoincrement_insert + def test_autoclose_on_insert(self): + if requirements.returning.enabled: + engine = engines.testing_engine( + options={"implicit_returning": False} + ) + else: + engine = config.db + + with engine.begin() as conn: + r = conn.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + assert r._soft_closed + assert not r.closed + assert r.is_insert + + # new as of I8091919d45421e3f53029b8660427f844fee0228; for the moment + # an insert where the PK was taken from a row that the dialect + # selected, as is the case for mssql/pyodbc, will still report + # returns_rows as true because there's a cursor description. in that + # case, the row had to have been consumed at least. + assert not r.returns_rows or r.fetchone() is None + + @requirements.returning + def test_autoclose_on_insert_implicit_returning(self, connection): + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + assert r._soft_closed + assert not r.closed + assert r.is_insert + + # note we are experimenting with having this be True + # as of I8091919d45421e3f53029b8660427f844fee0228 . + # implicit returning has fetched the row, but it still is a + # "returns rows" + assert r.returns_rows + + # and we should be able to fetchone() on it, we just get no row + eq_(r.fetchone(), None) + + # and the keys, etc. + eq_(r.keys(), ["id"]) + + # but the dialect took in the row already. not really sure + # what the best behavior is. + + @requirements.empty_inserts + def test_empty_insert(self, connection): + r = connection.execute(self.tables.autoinc_pk.insert()) + assert r._soft_closed + assert not r.closed + + r = connection.execute( + self.tables.autoinc_pk.select().where( + self.tables.autoinc_pk.c.id != None + ) + ) + eq_(len(r.all()), 1) + + @requirements.empty_inserts_executemany + def test_empty_insert_multiple(self, connection): + r = connection.execute(self.tables.autoinc_pk.insert(), [{}, {}, {}]) + assert r._soft_closed + assert not r.closed + + r = connection.execute( + self.tables.autoinc_pk.select().where( + self.tables.autoinc_pk.c.id != None + ) + ) + + eq_(len(r.all()), 3) + + @requirements.insert_from_select + def test_insert_from_select_autoinc(self, connection): + src_table = self.tables.manual_pk + dest_table = self.tables.autoinc_pk + connection.execute( + src_table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + result = connection.execute( + dest_table.insert().from_select( + ("data",), + select(src_table.c.data).where( + src_table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_(result.inserted_primary_key, (None,)) + + result = connection.execute( + select(dest_table.c.data).order_by(dest_table.c.data) + ) + eq_(result.fetchall(), [("data2",), ("data3",)]) + + @requirements.insert_from_select + def test_insert_from_select_autoinc_no_rows(self, connection): + src_table = self.tables.manual_pk + dest_table = self.tables.autoinc_pk + + result = connection.execute( + dest_table.insert().from_select( + ("data",), + select(src_table.c.data).where( + src_table.c.data.in_(["data2", "data3"]) + ), + ) + ) + eq_(result.inserted_primary_key, (None,)) + + result = connection.execute( + select(dest_table.c.data).order_by(dest_table.c.data) + ) + + eq_(result.fetchall(), []) + + @requirements.insert_from_select + def test_insert_from_select(self, connection): + table = self.tables.manual_pk + connection.execute( + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + connection.execute( + table.insert() + .inline() + .from_select( + ("id", "data"), + select(table.c.id + 5, table.c.data).where( + table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_( + connection.execute( + select(table.c.data).order_by(table.c.data) + ).fetchall(), + [("data1",), ("data2",), ("data2",), ("data3",), ("data3",)], + ) + + @requirements.insert_from_select + def test_insert_from_select_with_defaults(self, connection): + table = self.tables.includes_defaults + connection.execute( + table.insert(), + [ + dict(id=1, data="data1"), + dict(id=2, data="data2"), + dict(id=3, data="data3"), + ], + ) + + connection.execute( + table.insert() + .inline() + .from_select( + ("id", "data"), + select(table.c.id + 5, table.c.data).where( + table.c.data.in_(["data2", "data3"]) + ), + ) + ) + + eq_( + connection.execute( + select(table).order_by(table.c.data, table.c.id) + ).fetchall(), + [ + (1, "data1", 5, 4), + (2, "data2", 5, 4), + (7, "data2", 5, 4), + (3, "data3", 5, 4), + (8, "data3", 5, 4), + ], + ) + + +class ReturningTest(fixtures.TablesTest): + run_create_tables = "each" + __requires__ = "returning", "autoincrement_insert" + __backend__ = True + + __engine_options__ = {"implicit_returning": True} + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_( + row, + ( + conn.dialect.default_sequence_base, + "some data", + ), + ) + + @classmethod + def define_tables(cls, metadata): + Table( + "autoinc_pk", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("data", String(50)), + ) + + @requirements.fetch_rows_post_commit + def test_explicit_returning_pk_autocommit(self, connection): + table = self.tables.autoinc_pk + r = connection.execute( + table.insert().returning(table.c.id), dict(data="some data") + ) + pk = r.first()[0] + fetched_pk = connection.scalar(select(table.c.id)) + eq_(fetched_pk, pk) + + def test_explicit_returning_pk_no_autocommit(self, connection): + table = self.tables.autoinc_pk + r = connection.execute( + table.insert().returning(table.c.id), dict(data="some data") + ) + pk = r.first()[0] + fetched_pk = connection.scalar(select(table.c.id)) + eq_(fetched_pk, pk) + + def test_autoincrement_on_insert_implicit_returning(self, connection): + + connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.autoinc_pk, connection) + + def test_last_inserted_id_implicit_returning(self, connection): + + r = connection.execute( + self.tables.autoinc_pk.insert(), dict(data="some data") + ) + pk = connection.scalar(select(self.tables.autoinc_pk.c.id)) + eq_(r.inserted_primary_key, (pk,)) + + +__all__ = ("LastrowidTest", "InsertBehaviorTest", "ReturningTest") diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py new file mode 100644 index 0000000..459a4d8 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -0,0 +1,1738 @@ +import operator +import re + +import sqlalchemy as sa +from .. import config +from .. import engines +from .. import eq_ +from .. import expect_warnings +from .. import fixtures +from .. import is_ +from ..provision import get_temp_table_name +from ..provision import temp_table_keyword_args +from ..schema import Column +from ..schema import Table +from ... import event +from ... import ForeignKey +from ... import func +from ... import Identity +from ... import inspect +from ... import Integer +from ... import MetaData +from ... import String +from ... import testing +from ... import types as sql_types +from ...schema import DDL +from ...schema import Index +from ...sql.elements import quoted_name +from ...sql.schema import BLANK_SCHEMA +from ...testing import is_false +from ...testing import is_true + + +metadata, users = None, None + + +class HasTableTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + if testing.requires.schemas.enabled: + Table( + "test_table_s", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + schema=config.test_schema, + ) + + def test_has_table(self): + with config.db.begin() as conn: + is_true(config.db.dialect.has_table(conn, "test_table")) + is_false(config.db.dialect.has_table(conn, "test_table_s")) + is_false(config.db.dialect.has_table(conn, "nonexistent_table")) + + @testing.requires.schemas + def test_has_table_schema(self): + with config.db.begin() as conn: + is_false( + config.db.dialect.has_table( + conn, "test_table", schema=config.test_schema + ) + ) + is_true( + config.db.dialect.has_table( + conn, "test_table_s", schema=config.test_schema + ) + ) + is_false( + config.db.dialect.has_table( + conn, "nonexistent_table", schema=config.test_schema + ) + ) + + +class HasIndexTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + tt = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + Index("my_idx", tt.c.data) + + if testing.requires.schemas.enabled: + tt = Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + schema=config.test_schema, + ) + Index("my_idx_s", tt.c.data) + + def test_has_index(self): + with config.db.begin() as conn: + assert config.db.dialect.has_index(conn, "test_table", "my_idx") + assert not config.db.dialect.has_index( + conn, "test_table", "my_idx_s" + ) + assert not config.db.dialect.has_index( + conn, "nonexistent_table", "my_idx" + ) + assert not config.db.dialect.has_index( + conn, "test_table", "nonexistent_idx" + ) + + @testing.requires.schemas + def test_has_index_schema(self): + with config.db.begin() as conn: + assert config.db.dialect.has_index( + conn, "test_table", "my_idx_s", schema=config.test_schema + ) + assert not config.db.dialect.has_index( + conn, "test_table", "my_idx", schema=config.test_schema + ) + assert not config.db.dialect.has_index( + conn, + "nonexistent_table", + "my_idx_s", + schema=config.test_schema, + ) + assert not config.db.dialect.has_index( + conn, + "test_table", + "nonexistent_idx_s", + schema=config.test_schema, + ) + + +class QuotedNameArgumentTest(fixtures.TablesTest): + run_create_tables = "once" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "quote ' one", + metadata, + Column("id", Integer), + Column("name", String(50)), + Column("data", String(50)), + Column("related_id", Integer), + sa.PrimaryKeyConstraint("id", name="pk quote ' one"), + sa.Index("ix quote ' one", "name"), + sa.UniqueConstraint( + "data", + name="uq quote' one", + ), + sa.ForeignKeyConstraint( + ["id"], ["related.id"], name="fk quote ' one" + ), + sa.CheckConstraint("name != 'foo'", name="ck quote ' one"), + comment=r"""quote ' one comment""", + test_needs_fk=True, + ) + + if testing.requires.symbol_names_w_double_quote.enabled: + Table( + 'quote " two', + metadata, + Column("id", Integer), + Column("name", String(50)), + Column("data", String(50)), + Column("related_id", Integer), + sa.PrimaryKeyConstraint("id", name='pk quote " two'), + sa.Index('ix quote " two', "name"), + sa.UniqueConstraint( + "data", + name='uq quote" two', + ), + sa.ForeignKeyConstraint( + ["id"], ["related.id"], name='fk quote " two' + ), + sa.CheckConstraint("name != 'foo'", name='ck quote " two '), + comment=r"""quote " two comment""", + test_needs_fk=True, + ) + + Table( + "related", + metadata, + Column("id", Integer, primary_key=True), + Column("related", Integer), + test_needs_fk=True, + ) + + if testing.requires.view_column_reflection.enabled: + + if testing.requires.symbol_names_w_double_quote.enabled: + names = [ + "quote ' one", + 'quote " two', + ] + else: + names = [ + "quote ' one", + ] + for name in names: + query = "CREATE VIEW %s AS SELECT * FROM %s" % ( + config.db.dialect.identifier_preparer.quote( + "view %s" % name + ), + config.db.dialect.identifier_preparer.quote(name), + ) + + event.listen(metadata, "after_create", DDL(query)) + event.listen( + metadata, + "before_drop", + DDL( + "DROP VIEW %s" + % config.db.dialect.identifier_preparer.quote( + "view %s" % name + ) + ), + ) + + def quote_fixtures(fn): + return testing.combinations( + ("quote ' one",), + ('quote " two', testing.requires.symbol_names_w_double_quote), + )(fn) + + @quote_fixtures + def test_get_table_options(self, name): + insp = inspect(config.db) + + insp.get_table_options(name) + + @quote_fixtures + @testing.requires.view_column_reflection + def test_get_view_definition(self, name): + insp = inspect(config.db) + assert insp.get_view_definition("view %s" % name) + + @quote_fixtures + def test_get_columns(self, name): + insp = inspect(config.db) + assert insp.get_columns(name) + + @quote_fixtures + def test_get_pk_constraint(self, name): + insp = inspect(config.db) + assert insp.get_pk_constraint(name) + + @quote_fixtures + def test_get_foreign_keys(self, name): + insp = inspect(config.db) + assert insp.get_foreign_keys(name) + + @quote_fixtures + def test_get_indexes(self, name): + insp = inspect(config.db) + assert insp.get_indexes(name) + + @quote_fixtures + @testing.requires.unique_constraint_reflection + def test_get_unique_constraints(self, name): + insp = inspect(config.db) + assert insp.get_unique_constraints(name) + + @quote_fixtures + @testing.requires.comment_reflection + def test_get_table_comment(self, name): + insp = inspect(config.db) + assert insp.get_table_comment(name) + + @quote_fixtures + @testing.requires.check_constraint_reflection + def test_get_check_constraints(self, name): + insp = inspect(config.db) + assert insp.get_check_constraints(name) + + +class ComponentReflectionTest(fixtures.TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + + @classmethod + def setup_bind(cls): + if config.requirements.independent_connections.enabled: + from sqlalchemy import pool + + return engines.testing_engine( + options=dict(poolclass=pool.StaticPool, scope="class"), + ) + else: + return config.db + + @classmethod + def define_tables(cls, metadata): + cls.define_reflected_tables(metadata, None) + if testing.requires.schemas.enabled: + cls.define_reflected_tables(metadata, testing.config.test_schema) + + @classmethod + def define_reflected_tables(cls, metadata, schema): + if schema: + schema_prefix = schema + "." + else: + schema_prefix = "" + + if testing.requires.self_referential_foreign_keys.enabled: + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + Column( + "parent_user_id", + sa.Integer, + sa.ForeignKey( + "%susers.user_id" % schema_prefix, name="user_id_fk" + ), + ), + schema=schema, + test_needs_fk=True, + ) + else: + users = Table( + "users", + metadata, + Column("user_id", sa.INT, primary_key=True), + Column("test1", sa.CHAR(5), nullable=False), + Column("test2", sa.Float(5), nullable=False), + schema=schema, + test_needs_fk=True, + ) + + Table( + "dingalings", + metadata, + Column("dingaling_id", sa.Integer, primary_key=True), + Column( + "address_id", + sa.Integer, + sa.ForeignKey("%semail_addresses.address_id" % schema_prefix), + ), + Column("data", sa.String(30)), + schema=schema, + test_needs_fk=True, + ) + Table( + "email_addresses", + metadata, + Column("address_id", sa.Integer), + Column( + "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id) + ), + Column("email_address", sa.String(20)), + sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"), + schema=schema, + test_needs_fk=True, + ) + Table( + "comment_test", + metadata, + Column("id", sa.Integer, primary_key=True, comment="id comment"), + Column("data", sa.String(20), comment="data % comment"), + Column( + "d2", + sa.String(20), + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + schema=schema, + comment=r"""the test % ' " \ table comment""", + ) + + if testing.requires.cross_schema_fk_reflection.enabled: + if schema is None: + Table( + "local_table", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), + Column( + "remote_id", + ForeignKey( + "%s.remote_table_2.id" % testing.config.test_schema + ), + ), + test_needs_fk=True, + schema=config.db.dialect.default_schema_name, + ) + else: + Table( + "remote_table", + metadata, + Column("id", sa.Integer, primary_key=True), + Column( + "local_id", + ForeignKey( + "%s.local_table.id" + % config.db.dialect.default_schema_name + ), + ), + Column("data", sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + Table( + "remote_table_2", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("data", sa.String(20)), + schema=schema, + test_needs_fk=True, + ) + + if testing.requires.index_reflection.enabled: + cls.define_index(metadata, users) + + if not schema: + # test_needs_fk is at the moment to force MySQL InnoDB + noncol_idx_test_nopk = Table( + "noncol_idx_test_nopk", + metadata, + Column("q", sa.String(5)), + test_needs_fk=True, + ) + + noncol_idx_test_pk = Table( + "noncol_idx_test_pk", + metadata, + Column("id", sa.Integer, primary_key=True), + Column("q", sa.String(5)), + test_needs_fk=True, + ) + + if testing.requires.indexes_with_ascdesc.enabled: + Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) + Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) + + if testing.requires.view_column_reflection.enabled: + cls.define_views(metadata, schema) + if not schema and testing.requires.temp_table_reflection.enabled: + cls.define_temp_tables(metadata) + + @classmethod + def define_temp_tables(cls, metadata): + kw = temp_table_keyword_args(config, config.db) + table_name = get_temp_table_name( + config, config.db, "user_tmp_%s" % config.ident + ) + user_tmp = Table( + table_name, + metadata, + Column("id", sa.INT, primary_key=True), + Column("name", sa.VARCHAR(50)), + Column("foo", sa.INT), + # disambiguate temp table unique constraint names. this is + # pretty arbitrary for a generic dialect however we are doing + # it to suit SQL Server which will produce name conflicts for + # unique constraints created against temp tables in different + # databases. + # https://www.arbinada.com/en/node/1645 + sa.UniqueConstraint("name", name="user_tmp_uq_%s" % config.ident), + sa.Index("user_tmp_ix", "foo"), + **kw + ) + if ( + testing.requires.view_reflection.enabled + and testing.requires.temporary_views.enabled + ): + event.listen( + user_tmp, + "after_create", + DDL( + "create temporary view user_tmp_v as " + "select * from user_tmp_%s" % config.ident + ), + ) + event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) + + @classmethod + def define_index(cls, metadata, users): + Index("users_t_idx", users.c.test1, users.c.test2) + Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1) + + @classmethod + def define_views(cls, metadata, schema): + for table_name in ("users", "email_addresses"): + fullname = table_name + if schema: + fullname = "%s.%s" % (schema, table_name) + view_name = fullname + "_v" + query = "CREATE VIEW %s AS SELECT * FROM %s" % ( + view_name, + fullname, + ) + + event.listen(metadata, "after_create", DDL(query)) + event.listen( + metadata, "before_drop", DDL("DROP VIEW %s" % view_name) + ) + + @testing.requires.schema_reflection + def test_get_schema_names(self): + insp = inspect(self.bind) + + self.assert_(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_get_schema_names_w_translate_map(self, connection): + """test #7300""" + + connection = connection.execution_options( + schema_translate_map={ + "foo": "bar", + BLANK_SCHEMA: testing.config.test_schema, + } + ) + insp = inspect(connection) + + self.assert_(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_dialect_initialize(self): + engine = engines.testing_engine() + inspect(engine) + assert hasattr(engine.dialect, "default_schema_name") + + @testing.requires.schema_reflection + def test_get_default_schema_name(self): + insp = inspect(self.bind) + eq_(insp.default_schema_name, self.bind.dialect.default_schema_name) + + @testing.requires.foreign_key_constraint_reflection + @testing.combinations( + (None, True, False, False), + (None, True, False, True, testing.requires.schemas), + ("foreign_key", True, False, False), + (None, False, True, False), + (None, False, True, True, testing.requires.schemas), + (None, True, True, False), + (None, True, True, True, testing.requires.schemas), + argnames="order_by,include_plain,include_views,use_schema", + ) + def test_get_table_names( + self, connection, order_by, include_plain, include_views, use_schema + ): + + if use_schema: + schema = config.test_schema + else: + schema = None + + _ignore_tables = [ + "comment_test", + "noncol_idx_test_pk", + "noncol_idx_test_nopk", + "local_table", + "remote_table", + "remote_table_2", + ] + + insp = inspect(connection) + + if include_views: + table_names = insp.get_view_names(schema) + table_names.sort() + answer = ["email_addresses_v", "users_v"] + eq_(sorted(table_names), answer) + + if include_plain: + if order_by: + tables = [ + rec[0] + for rec in insp.get_sorted_table_and_fkc_names(schema) + if rec[0] + ] + else: + tables = insp.get_table_names(schema) + table_names = [t for t in tables if t not in _ignore_tables] + + if order_by == "foreign_key": + answer = ["users", "email_addresses", "dingalings"] + eq_(table_names, answer) + else: + answer = ["dingalings", "email_addresses", "users"] + eq_(sorted(table_names), answer) + + @testing.requires.temp_table_names + def test_get_temp_table_names(self): + insp = inspect(self.bind) + temp_table_names = insp.get_temp_table_names() + eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident]) + + @testing.requires.view_reflection + @testing.requires.temp_table_names + @testing.requires.temporary_views + def test_get_temp_view_names(self): + insp = inspect(self.bind) + temp_table_names = insp.get_temp_view_names() + eq_(sorted(temp_table_names), ["user_tmp_v"]) + + @testing.requires.comment_reflection + def test_get_comments(self): + self._test_get_comments() + + @testing.requires.comment_reflection + @testing.requires.schemas + def test_get_comments_with_schema(self): + self._test_get_comments(testing.config.test_schema) + + def _test_get_comments(self, schema=None): + insp = inspect(self.bind) + + eq_( + insp.get_table_comment("comment_test", schema=schema), + {"text": r"""the test % ' " \ table comment"""}, + ) + + eq_(insp.get_table_comment("users", schema=schema), {"text": None}) + + eq_( + [ + {"name": rec["name"], "comment": rec["comment"]} + for rec in insp.get_columns("comment_test", schema=schema) + ], + [ + {"comment": "id comment", "name": "id"}, + {"comment": "data % comment", "name": "data"}, + { + "comment": ( + r"""Comment types type speedily ' " \ '' Fun!""" + ), + "name": "d2", + }, + ], + ) + + @testing.combinations( + (False, False), + (False, True, testing.requires.schemas), + (True, False, testing.requires.view_reflection), + ( + True, + True, + testing.requires.schemas + testing.requires.view_reflection, + ), + argnames="use_views,use_schema", + ) + def test_get_columns(self, connection, use_views, use_schema): + + if use_schema: + schema = config.test_schema + else: + schema = None + + users, addresses = (self.tables.users, self.tables.email_addresses) + if use_views: + table_names = ["users_v", "email_addresses_v"] + else: + table_names = ["users", "email_addresses"] + + insp = inspect(connection) + for table_name, table in zip(table_names, (users, addresses)): + schema_name = schema + cols = insp.get_columns(table_name, schema=schema_name) + self.assert_(len(cols) > 0, len(cols)) + + # should be in order + + for i, col in enumerate(table.columns): + eq_(col.name, cols[i]["name"]) + ctype = cols[i]["type"].__class__ + ctype_def = col.type + if isinstance(ctype_def, sa.types.TypeEngine): + ctype_def = ctype_def.__class__ + + # Oracle returns Date for DateTime. + + if testing.against("oracle") and ctype_def in ( + sql_types.Date, + sql_types.DateTime, + ): + ctype_def = sql_types.Date + + # assert that the desired type and return type share + # a base within one of the generic types. + + self.assert_( + len( + set(ctype.__mro__) + .intersection(ctype_def.__mro__) + .intersection( + [ + sql_types.Integer, + sql_types.Numeric, + sql_types.DateTime, + sql_types.Date, + sql_types.Time, + sql_types.String, + sql_types._Binary, + ] + ) + ) + > 0, + "%s(%s), %s(%s)" + % (col.name, col.type, cols[i]["name"], ctype), + ) + + if not col.primary_key: + assert cols[i]["default"] is None + + @testing.requires.temp_table_reflection + def test_get_temp_table_columns(self): + table_name = get_temp_table_name( + config, self.bind, "user_tmp_%s" % config.ident + ) + user_tmp = self.tables[table_name] + insp = inspect(self.bind) + cols = insp.get_columns(table_name) + self.assert_(len(cols) > 0, len(cols)) + + for i, col in enumerate(user_tmp.columns): + eq_(col.name, cols[i]["name"]) + + @testing.requires.temp_table_reflection + @testing.requires.view_column_reflection + @testing.requires.temporary_views + def test_get_temp_view_columns(self): + insp = inspect(self.bind) + cols = insp.get_columns("user_tmp_v") + eq_([col["name"] for col in cols], ["id", "name", "foo"]) + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.primary_key_constraint_reflection + def test_get_pk_constraint(self, connection, use_schema): + if use_schema: + schema = testing.config.test_schema + else: + schema = None + + users, addresses = self.tables.users, self.tables.email_addresses + insp = inspect(connection) + + users_cons = insp.get_pk_constraint(users.name, schema=schema) + users_pkeys = users_cons["constrained_columns"] + eq_(users_pkeys, ["user_id"]) + + addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) + addr_pkeys = addr_cons["constrained_columns"] + eq_(addr_pkeys, ["address_id"]) + + with testing.requires.reflects_pk_names.fail_if(): + eq_(addr_cons["name"], "email_ad_pk") + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + @testing.requires.foreign_key_constraint_reflection + def test_get_foreign_keys(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + users, addresses = (self.tables.users, self.tables.email_addresses) + insp = inspect(connection) + expected_schema = schema + # users + + if testing.requires.self_referential_foreign_keys.enabled: + users_fkeys = insp.get_foreign_keys(users.name, schema=schema) + fkey1 = users_fkeys[0] + + with testing.requires.named_constraints.fail_if(): + eq_(fkey1["name"], "user_id_fk") + + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + if testing.requires.self_referential_foreign_keys.enabled: + eq_(fkey1["constrained_columns"], ["parent_user_id"]) + + # addresses + addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) + fkey1 = addr_fkeys[0] + + with testing.requires.implicitly_named_constraints.fail_if(): + self.assert_(fkey1["name"] is not None) + + eq_(fkey1["referred_schema"], expected_schema) + eq_(fkey1["referred_table"], users.name) + eq_(fkey1["referred_columns"], ["user_id"]) + eq_(fkey1["constrained_columns"], ["remote_user_id"]) + + @testing.requires.cross_schema_fk_reflection + @testing.requires.schemas + def test_get_inter_schema_foreign_keys(self): + local_table, remote_table, remote_table_2 = self.tables( + "%s.local_table" % self.bind.dialect.default_schema_name, + "%s.remote_table" % testing.config.test_schema, + "%s.remote_table_2" % testing.config.test_schema, + ) + + insp = inspect(self.bind) + + local_fkeys = insp.get_foreign_keys(local_table.name) + eq_(len(local_fkeys), 1) + + fkey1 = local_fkeys[0] + eq_(fkey1["referred_schema"], testing.config.test_schema) + eq_(fkey1["referred_table"], remote_table_2.name) + eq_(fkey1["referred_columns"], ["id"]) + eq_(fkey1["constrained_columns"], ["remote_id"]) + + remote_fkeys = insp.get_foreign_keys( + remote_table.name, schema=testing.config.test_schema + ) + eq_(len(remote_fkeys), 1) + + fkey2 = remote_fkeys[0] + + assert fkey2["referred_schema"] in ( + None, + self.bind.dialect.default_schema_name, + ) + eq_(fkey2["referred_table"], local_table.name) + eq_(fkey2["referred_columns"], ["id"]) + eq_(fkey2["constrained_columns"], ["local_id"]) + + def _assert_insp_indexes(self, indexes, expected_indexes): + index_names = [d["name"] for d in indexes] + for e_index in expected_indexes: + assert e_index["name"] in index_names + index = indexes[index_names.index(e_index["name"])] + for key in e_index: + eq_(e_index[key], index[key]) + + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_indexes(self, connection, use_schema): + + if use_schema: + schema = config.test_schema + else: + schema = None + + # The database may decide to create indexes for foreign keys, etc. + # so there may be more indexes than expected. + insp = inspect(self.bind) + indexes = insp.get_indexes("users", schema=schema) + expected_indexes = [ + { + "unique": False, + "column_names": ["test1", "test2"], + "name": "users_t_idx", + }, + { + "unique": False, + "column_names": ["user_id", "test2", "test1"], + "name": "users_all_idx", + }, + ] + self._assert_insp_indexes(indexes, expected_indexes) + + @testing.combinations( + ("noncol_idx_test_nopk", "noncol_idx_nopk"), + ("noncol_idx_test_pk", "noncol_idx_pk"), + argnames="tname,ixname", + ) + @testing.requires.index_reflection + @testing.requires.indexes_with_ascdesc + def test_get_noncol_index(self, connection, tname, ixname): + insp = inspect(connection) + indexes = insp.get_indexes(tname) + + # reflecting an index that has "x DESC" in it as the column. + # the DB may or may not give us "x", but make sure we get the index + # back, it has a name, it's connected to the table. + expected_indexes = [{"unique": False, "name": ixname}] + self._assert_insp_indexes(indexes, expected_indexes) + + t = Table(tname, MetaData(), autoload_with=connection) + eq_(len(t.indexes), 1) + is_(list(t.indexes)[0].table, t) + eq_(list(t.indexes)[0].name, ixname) + + @testing.requires.temp_table_reflection + @testing.requires.unique_constraint_reflection + def test_get_temp_table_unique_constraints(self): + insp = inspect(self.bind) + reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident) + for refl in reflected: + # Different dialects handle duplicate index and constraints + # differently, so ignore this flag + refl.pop("duplicates_index", None) + eq_( + reflected, + [ + { + "column_names": ["name"], + "name": "user_tmp_uq_%s" % config.ident, + } + ], + ) + + @testing.requires.temp_table_reflect_indexes + def test_get_temp_table_indexes(self): + insp = inspect(self.bind) + table_name = get_temp_table_name( + config, config.db, "user_tmp_%s" % config.ident + ) + indexes = insp.get_indexes(table_name) + for ind in indexes: + ind.pop("dialect_options", None) + expected = [ + {"unique": False, "column_names": ["foo"], "name": "user_tmp_ix"} + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + eq_( + [idx for idx in indexes if idx["name"] == "user_tmp_ix"], + expected, + ) + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + @testing.requires.unique_constraint_reflection + def test_get_unique_constraints(self, metadata, connection, use_schema): + # SQLite dialect needs to parse the names of the constraints + # separately from what it gets from PRAGMA index_list(), and + # then matches them up. so same set of column_names in two + # constraints will confuse it. Perhaps we should no longer + # bother with index_list() here since we have the whole + # CREATE TABLE? + + if use_schema: + schema = config.test_schema + else: + schema = None + uniques = sorted( + [ + {"name": "unique_a", "column_names": ["a"]}, + {"name": "unique_a_b_c", "column_names": ["a", "b", "c"]}, + {"name": "unique_c_a_b", "column_names": ["c", "a", "b"]}, + {"name": "unique_asc_key", "column_names": ["asc", "key"]}, + {"name": "i.have.dots", "column_names": ["b"]}, + {"name": "i have spaces", "column_names": ["c"]}, + ], + key=operator.itemgetter("name"), + ) + table = Table( + "testtbl", + metadata, + Column("a", sa.String(20)), + Column("b", sa.String(30)), + Column("c", sa.Integer), + # reserved identifiers + Column("asc", sa.String(30)), + Column("key", sa.String(30)), + schema=schema, + ) + for uc in uniques: + table.append_constraint( + sa.UniqueConstraint(*uc["column_names"], name=uc["name"]) + ) + table.create(connection) + + inspector = inspect(connection) + reflected = sorted( + inspector.get_unique_constraints("testtbl", schema=schema), + key=operator.itemgetter("name"), + ) + + names_that_duplicate_index = set() + + for orig, refl in zip(uniques, reflected): + # Different dialects handle duplicate index and constraints + # differently, so ignore this flag + dupe = refl.pop("duplicates_index", None) + if dupe: + names_that_duplicate_index.add(dupe) + eq_(orig, refl) + + reflected_metadata = MetaData() + reflected = Table( + "testtbl", + reflected_metadata, + autoload_with=connection, + schema=schema, + ) + + # test "deduplicates for index" logic. MySQL and Oracle + # "unique constraints" are actually unique indexes (with possible + # exception of a unique that is a dupe of another one in the case + # of Oracle). make sure # they aren't duplicated. + idx_names = set([idx.name for idx in reflected.indexes]) + uq_names = set( + [ + uq.name + for uq in reflected.constraints + if isinstance(uq, sa.UniqueConstraint) + ] + ).difference(["unique_c_a_b"]) + + assert not idx_names.intersection(uq_names) + if names_that_duplicate_index: + eq_(names_that_duplicate_index, idx_names) + eq_(uq_names, set()) + + @testing.requires.view_reflection + @testing.combinations( + (False,), (True, testing.requires.schemas), argnames="use_schema" + ) + def test_get_view_definition(self, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + view_name1 = "users_v" + view_name2 = "email_addresses_v" + insp = inspect(connection) + v1 = insp.get_view_definition(view_name1, schema=schema) + self.assert_(v1) + v2 = insp.get_view_definition(view_name2, schema=schema) + self.assert_(v2) + + # why is this here if it's PG specific ? + @testing.combinations( + ("users", False), + ("users", True, testing.requires.schemas), + argnames="table_name,use_schema", + ) + @testing.only_on("postgresql", "PG specific feature") + def test_get_table_oid(self, connection, table_name, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + insp = inspect(connection) + oid = insp.get_table_oid(table_name, schema) + self.assert_(isinstance(oid, int)) + + @testing.requires.table_reflection + def test_autoincrement_col(self): + """test that 'autoincrement' is reflected according to sqla's policy. + + Don't mark this test as unsupported for any backend ! + + (technically it fails with MySQL InnoDB since "id" comes before "id2") + + A backend is better off not returning "autoincrement" at all, + instead of potentially returning "False" for an auto-incrementing + primary key column. + + """ + + insp = inspect(self.bind) + + for tname, cname in [ + ("users", "user_id"), + ("email_addresses", "address_id"), + ("dingalings", "dingaling_id"), + ]: + cols = insp.get_columns(tname) + id_ = {c["name"]: c for c in cols}[cname] + assert id_.get("autoincrement", True) + + +class TableNoColumnsTest(fixtures.TestBase): + __requires__ = ("reflect_tables_no_columns",) + __backend__ = True + + @testing.fixture + def table_no_columns(self, connection, metadata): + Table("empty", metadata) + metadata.create_all(connection) + + @testing.fixture + def view_no_columns(self, connection, metadata): + Table("empty", metadata) + metadata.create_all(connection) + + Table("empty", metadata) + event.listen( + metadata, + "after_create", + DDL("CREATE VIEW empty_v AS SELECT * FROM empty"), + ) + + # for transactional DDL the transaction is rolled back before this + # drop statement is invoked + event.listen( + metadata, "before_drop", DDL("DROP VIEW IF EXISTS empty_v") + ) + metadata.create_all(connection) + + @testing.requires.reflect_tables_no_columns + def test_reflect_table_no_columns(self, connection, table_no_columns): + t2 = Table("empty", MetaData(), autoload_with=connection) + eq_(list(t2.c), []) + + @testing.requires.reflect_tables_no_columns + def test_get_columns_table_no_columns(self, connection, table_no_columns): + eq_(inspect(connection).get_columns("empty"), []) + + @testing.requires.reflect_tables_no_columns + def test_reflect_incl_table_no_columns(self, connection, table_no_columns): + m = MetaData() + m.reflect(connection) + assert set(m.tables).intersection(["empty"]) + + @testing.requires.views + @testing.requires.reflect_tables_no_columns + def test_reflect_view_no_columns(self, connection, view_no_columns): + t2 = Table("empty_v", MetaData(), autoload_with=connection) + eq_(list(t2.c), []) + + @testing.requires.views + @testing.requires.reflect_tables_no_columns + def test_get_columns_view_no_columns(self, connection, view_no_columns): + eq_(inspect(connection).get_columns("empty_v"), []) + + +class ComponentReflectionTestExtra(fixtures.TestBase): + + __backend__ = True + + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + @testing.requires.check_constraint_reflection + def test_get_check_constraints(self, metadata, connection, use_schema): + if use_schema: + schema = config.test_schema + else: + schema = None + + Table( + "sa_cc", + metadata, + Column("a", Integer()), + sa.CheckConstraint("a > 1 AND a < 5", name="cc1"), + sa.CheckConstraint( + "a = 1 OR (a > 2 AND a < 5)", name="UsesCasing" + ), + schema=schema, + ) + + metadata.create_all(connection) + + inspector = inspect(connection) + reflected = sorted( + inspector.get_check_constraints("sa_cc", schema=schema), + key=operator.itemgetter("name"), + ) + + # trying to minimize effect of quoting, parenthesis, etc. + # may need to add more to this as new dialects get CHECK + # constraint reflection support + def normalize(sqltext): + return " ".join( + re.findall(r"and|\d|=|a|or|<|>", sqltext.lower(), re.I) + ) + + reflected = [ + {"name": item["name"], "sqltext": normalize(item["sqltext"])} + for item in reflected + ] + eq_( + reflected, + [ + {"name": "UsesCasing", "sqltext": "a = 1 or a > 2 and a < 5"}, + {"name": "cc1", "sqltext": "a > 1 and a < 5"}, + ], + ) + + @testing.requires.indexes_with_expressions + def test_reflect_expression_based_indexes(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) + + Index("t_idx", func.lower(t.c.x), func.lower(t.c.y)) + + Index("t_idx_2", t.c.x) + + metadata.create_all(connection) + + insp = inspect(connection) + + expected = [ + {"name": "t_idx_2", "column_names": ["x"], "unique": False} + ] + if testing.requires.index_reflects_included_columns.enabled: + expected[0]["include_columns"] = [] + expected[0]["dialect_options"] = { + "%s_include" % connection.engine.name: [] + } + + with expect_warnings( + "Skipped unsupported reflection of expression-based index t_idx" + ): + eq_( + insp.get_indexes("t"), + expected, + ) + + @testing.requires.index_reflects_included_columns + def test_reflect_covering_index(self, metadata, connection): + t = Table( + "t", + metadata, + Column("x", String(30)), + Column("y", String(30)), + ) + idx = Index("t_idx", t.c.x) + idx.dialect_options[connection.engine.name]["include"] = ["y"] + + metadata.create_all(connection) + + insp = inspect(connection) + + eq_( + insp.get_indexes("t"), + [ + { + "name": "t_idx", + "column_names": ["x"], + "include_columns": ["y"], + "unique": False, + "dialect_options": { + "%s_include" % connection.engine.name: ["y"] + }, + } + ], + ) + + t2 = Table("t", MetaData(), autoload_with=connection) + eq_( + list(t2.indexes)[0].dialect_options[connection.engine.name][ + "include" + ], + ["y"], + ) + + def _type_round_trip(self, connection, metadata, *types): + t = Table( + "t", + metadata, + *[Column("t%d" % i, type_) for i, type_ in enumerate(types)] + ) + t.create(connection) + + return [c["type"] for c in inspect(connection).get_columns("t")] + + @testing.requires.table_reflection + def test_numeric_reflection(self, connection, metadata): + for typ in self._type_round_trip( + connection, metadata, sql_types.Numeric(18, 5) + ): + assert isinstance(typ, sql_types.Numeric) + eq_(typ.precision, 18) + eq_(typ.scale, 5) + + @testing.requires.table_reflection + def test_varchar_reflection(self, connection, metadata): + typ = self._type_round_trip( + connection, metadata, sql_types.String(52) + )[0] + assert isinstance(typ, sql_types.String) + eq_(typ.length, 52) + + @testing.requires.table_reflection + def test_nullable_reflection(self, connection, metadata): + t = Table( + "t", + metadata, + Column("a", Integer, nullable=True), + Column("b", Integer, nullable=False), + ) + t.create(connection) + eq_( + dict( + (col["name"], col["nullable"]) + for col in inspect(connection).get_columns("t") + ), + {"a": True, "b": False}, + ) + + @testing.combinations( + ( + None, + "CASCADE", + None, + testing.requires.foreign_key_constraint_option_reflection_ondelete, + ), + ( + None, + None, + "SET NULL", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + None, + "NO ACTION", + testing.requires.foreign_key_constraint_option_reflection_onupdate, + ), + ( + {}, + "NO ACTION", + None, + testing.requires.fk_constraint_option_reflection_ondelete_noaction, + ), + ( + None, + None, + "RESTRICT", + testing.requires.fk_constraint_option_reflection_onupdate_restrict, + ), + ( + None, + "RESTRICT", + None, + testing.requires.fk_constraint_option_reflection_ondelete_restrict, + ), + argnames="expected,ondelete,onupdate", + ) + def test_get_foreign_key_options( + self, connection, metadata, expected, ondelete, onupdate + ): + options = {} + if ondelete: + options["ondelete"] = ondelete + if onupdate: + options["onupdate"] = onupdate + + if expected is None: + expected = options + + Table( + "x", + metadata, + Column("id", Integer, primary_key=True), + test_needs_fk=True, + ) + + Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("x_id", Integer, sa.ForeignKey("x.id", name="xid")), + Column("test", String(10)), + test_needs_fk=True, + ) + + Table( + "user", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(50), nullable=False), + Column("tid", Integer), + sa.ForeignKeyConstraint( + ["tid"], ["table.id"], name="myfk", **options + ), + test_needs_fk=True, + ) + + metadata.create_all(connection) + + insp = inspect(connection) + + # test 'options' is always present for a backend + # that can reflect these, since alembic looks for this + opts = insp.get_foreign_keys("table")[0]["options"] + + eq_(dict((k, opts[k]) for k in opts if opts[k]), {}) + + opts = insp.get_foreign_keys("user")[0]["options"] + eq_(opts, expected) + # eq_(dict((k, opts[k]) for k in opts if opts[k]), expected) + + +class NormalizedNameTest(fixtures.TablesTest): + __requires__ = ("denormalized_names",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + quoted_name("t1", quote=True), + metadata, + Column("id", Integer, primary_key=True), + ) + Table( + quoted_name("t2", quote=True), + metadata, + Column("id", Integer, primary_key=True), + Column("t1id", ForeignKey("t1.id")), + ) + + def test_reflect_lowercase_forced_tables(self): + + m2 = MetaData() + t2_ref = Table( + quoted_name("t2", quote=True), m2, autoload_with=config.db + ) + t1_ref = m2.tables["t1"] + assert t2_ref.c.t1id.references(t1_ref.c.id) + + m3 = MetaData() + m3.reflect( + config.db, only=lambda name, m: name.lower() in ("t1", "t2") + ) + assert m3.tables["t2"].c.t1id.references(m3.tables["t1"].c.id) + + def test_get_table_names(self): + tablenames = [ + t + for t in inspect(config.db).get_table_names() + if t.lower() in ("t1", "t2") + ] + + eq_(tablenames[0].upper(), tablenames[0].lower()) + eq_(tablenames[1].upper(), tablenames[1].lower()) + + +class ComputedReflectionTest(fixtures.ComputedReflectionFixtureTest): + def test_computed_col_default_not_set(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_default_table") + col_data = {c["name"]: c for c in cols} + is_true("42" in col_data["with_default"]["default"]) + is_(col_data["normal"]["default"], None) + is_(col_data["computed_col"]["default"], None) + + def test_get_column_returns_computed(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_default_table") + data = {c["name"]: c for c in cols} + for key in ("id", "normal", "with_default"): + is_true("computed" not in data[key]) + compData = data["computed_col"] + is_true("computed" in compData) + is_true("sqltext" in compData["computed"]) + eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42") + eq_( + "persisted" in compData["computed"], + testing.requires.computed_columns_reflect_persisted.enabled, + ) + if testing.requires.computed_columns_reflect_persisted.enabled: + eq_( + compData["computed"]["persisted"], + testing.requires.computed_columns_default_persisted.enabled, + ) + + def check_column(self, data, column, sqltext, persisted): + is_true("computed" in data[column]) + compData = data[column]["computed"] + eq_(self.normalize(compData["sqltext"]), sqltext) + if testing.requires.computed_columns_reflect_persisted.enabled: + is_true("persisted" in compData) + is_(compData["persisted"], persisted) + + def test_get_column_returns_persisted(self): + insp = inspect(config.db) + + cols = insp.get_columns("computed_column_table") + data = {c["name"]: c for c in cols} + + self.check_column( + data, + "computed_no_flag", + "normal+42", + testing.requires.computed_columns_default_persisted.enabled, + ) + if testing.requires.computed_columns_virtual.enabled: + self.check_column( + data, + "computed_virtual", + "normal+2", + False, + ) + if testing.requires.computed_columns_stored.enabled: + self.check_column( + data, + "computed_stored", + "normal-42", + True, + ) + + @testing.requires.schemas + def test_get_column_returns_persisted_with_schema(self): + insp = inspect(config.db) + + cols = insp.get_columns( + "computed_column_table", schema=config.test_schema + ) + data = {c["name"]: c for c in cols} + + self.check_column( + data, + "computed_no_flag", + "normal/42", + testing.requires.computed_columns_default_persisted.enabled, + ) + if testing.requires.computed_columns_virtual.enabled: + self.check_column( + data, + "computed_virtual", + "normal/2", + False, + ) + if testing.requires.computed_columns_stored.enabled: + self.check_column( + data, + "computed_stored", + "normal*42", + True, + ) + + +class IdentityReflectionTest(fixtures.TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + __requires__ = ("identity_columns", "table_reflection") + + @classmethod + def define_tables(cls, metadata): + Table( + "t1", + metadata, + Column("normal", Integer), + Column("id1", Integer, Identity()), + ) + Table( + "t2", + metadata, + Column( + "id2", + Integer, + Identity( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + ), + ) + if testing.requires.schemas.enabled: + Table( + "t1", + metadata, + Column("normal", Integer), + Column("id1", Integer, Identity(always=True, start=20)), + schema=config.test_schema, + ) + + def check(self, value, exp, approx): + if testing.requires.identity_columns_standard.enabled: + common_keys = ( + "always", + "start", + "increment", + "minvalue", + "maxvalue", + "cycle", + "cache", + ) + for k in list(value): + if k not in common_keys: + value.pop(k) + if approx: + eq_(len(value), len(exp)) + for k in value: + if k == "minvalue": + is_true(value[k] <= exp[k]) + elif k in {"maxvalue", "cache"}: + is_true(value[k] >= exp[k]) + else: + eq_(value[k], exp[k], k) + else: + eq_(value, exp) + else: + eq_(value["start"], exp["start"]) + eq_(value["increment"], exp["increment"]) + + def test_reflect_identity(self): + insp = inspect(config.db) + + cols = insp.get_columns("t1") + insp.get_columns("t2") + for col in cols: + if col["name"] == "normal": + is_false("identity" in col) + elif col["name"] == "id1": + is_true(col["autoincrement"] in (True, "auto")) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=False, + start=1, + increment=1, + minvalue=1, + maxvalue=2147483647, + cycle=False, + cache=1, + ), + approx=True, + ) + elif col["name"] == "id2": + is_true(col["autoincrement"] in (True, "auto")) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + approx=False, + ) + + @testing.requires.schemas + def test_reflect_identity_schema(self): + insp = inspect(config.db) + + cols = insp.get_columns("t1", schema=config.test_schema) + for col in cols: + if col["name"] == "normal": + is_false("identity" in col) + elif col["name"] == "id1": + is_true(col["autoincrement"] in (True, "auto")) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=True, + start=20, + increment=1, + minvalue=1, + maxvalue=2147483647, + cycle=False, + cache=1, + ), + approx=True, + ) + + +class CompositeKeyReflectionTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + tb1 = Table( + "tb1", + metadata, + Column("id", Integer), + Column("attr", Integer), + Column("name", sql_types.VARCHAR(20)), + sa.PrimaryKeyConstraint("name", "id", "attr", name="pk_tb1"), + schema=None, + test_needs_fk=True, + ) + Table( + "tb2", + metadata, + Column("id", Integer, primary_key=True), + Column("pid", Integer), + Column("pattr", Integer), + Column("pname", sql_types.VARCHAR(20)), + sa.ForeignKeyConstraint( + ["pname", "pid", "pattr"], + [tb1.c.name, tb1.c.id, tb1.c.attr], + name="fk_tb1_name_id_attr", + ), + schema=None, + test_needs_fk=True, + ) + + @testing.requires.primary_key_constraint_reflection + def test_pk_column_order(self): + # test for issue #5661 + insp = inspect(self.bind) + primary_key = insp.get_pk_constraint(self.tables.tb1.name) + eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"]) + + @testing.requires.foreign_key_constraint_reflection + def test_fk_column_order(self): + # test for issue #5661 + insp = inspect(self.bind) + foreign_keys = insp.get_foreign_keys(self.tables.tb2.name) + eq_(len(foreign_keys), 1) + fkey1 = foreign_keys[0] + eq_(fkey1.get("referred_columns"), ["name", "id", "attr"]) + eq_(fkey1.get("constrained_columns"), ["pname", "pid", "pattr"]) + + +__all__ = ( + "ComponentReflectionTest", + "ComponentReflectionTestExtra", + "TableNoColumnsTest", + "QuotedNameArgumentTest", + "HasTableTest", + "HasIndexTest", + "NormalizedNameTest", + "ComputedReflectionTest", + "IdentityReflectionTest", + "CompositeKeyReflectionTest", +) diff --git a/lib/sqlalchemy/testing/suite/test_results.py b/lib/sqlalchemy/testing/suite/test_results.py new file mode 100644 index 0000000..c41a550 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_results.py @@ -0,0 +1,426 @@ +import datetime + +from .. import engines +from .. import fixtures +from ..assertions import eq_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import DateTime +from ... import func +from ... import Integer +from ... import select +from ... import sql +from ... import String +from ... import testing +from ... import text +from ... import util + + +class RowFetchTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + Table( + "has_dates", + metadata, + Column("id", Integer, primary_key=True), + Column("today", DateTime), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.plain_pk.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + connection.execute( + cls.tables.has_dates.insert(), + [{"id": 1, "today": datetime.datetime(2006, 5, 12, 12, 0, 0)}], + ) + + def test_via_attr(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row.id, 1) + eq_(row.data, "d1") + + def test_via_string(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row._mapping["id"], 1) + eq_(row._mapping["data"], "d1") + + def test_via_int(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row[0], 1) + eq_(row[1], "d1") + + def test_via_col_object(self, connection): + row = connection.execute( + self.tables.plain_pk.select().order_by(self.tables.plain_pk.c.id) + ).first() + + eq_(row._mapping[self.tables.plain_pk.c.id], 1) + eq_(row._mapping[self.tables.plain_pk.c.data], "d1") + + @requirements.duplicate_names_in_cursor_description + def test_row_with_dupe_names(self, connection): + result = connection.execute( + select( + self.tables.plain_pk.c.data, + self.tables.plain_pk.c.data.label("data"), + ).order_by(self.tables.plain_pk.c.id) + ) + row = result.first() + eq_(result.keys(), ["data", "data"]) + eq_(row, ("d1", "d1")) + + def test_row_w_scalar_select(self, connection): + """test that a scalar select as a column is returned as such + and that type conversion works OK. + + (this is half a SQLAlchemy Core test and half to catch database + backends that may have unusual behavior with scalar selects.) + + """ + datetable = self.tables.has_dates + s = select(datetable.alias("x").c.today).scalar_subquery() + s2 = select(datetable.c.id, s.label("somelabel")) + row = connection.execute(s2).first() + + eq_(row.somelabel, datetime.datetime(2006, 5, 12, 12, 0, 0)) + + +class PercentSchemaNamesTest(fixtures.TablesTest): + """tests using percent signs, spaces in table and column names. + + This didn't work for PostgreSQL / MySQL drivers for a long time + but is now supported. + + """ + + __requires__ = ("percent_schema_names",) + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + cls.tables.percent_table = Table( + "percent%table", + metadata, + Column("percent%", Integer), + Column("spaces % more spaces", Integer), + ) + cls.tables.lightweight_percent_table = sql.table( + "percent%table", + sql.column("percent%"), + sql.column("spaces % more spaces"), + ) + + def test_single_roundtrip(self, connection): + percent_table = self.tables.percent_table + for params in [ + {"percent%": 5, "spaces % more spaces": 12}, + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ]: + connection.execute(percent_table.insert(), params) + self._assert_table(connection) + + def test_executemany_roundtrip(self, connection): + percent_table = self.tables.percent_table + connection.execute( + percent_table.insert(), {"percent%": 5, "spaces % more spaces": 12} + ) + connection.execute( + percent_table.insert(), + [ + {"percent%": 7, "spaces % more spaces": 11}, + {"percent%": 9, "spaces % more spaces": 10}, + {"percent%": 11, "spaces % more spaces": 9}, + ], + ) + self._assert_table(connection) + + def _assert_table(self, conn): + percent_table = self.tables.percent_table + lightweight_percent_table = self.tables.lightweight_percent_table + + for table in ( + percent_table, + percent_table.alias(), + lightweight_percent_table, + lightweight_percent_table.alias(), + ): + eq_( + list( + conn.execute(table.select().order_by(table.c["percent%"])) + ), + [(5, 12), (7, 11), (9, 10), (11, 9)], + ) + + eq_( + list( + conn.execute( + table.select() + .where(table.c["spaces % more spaces"].in_([9, 10])) + .order_by(table.c["percent%"]) + ) + ), + [(9, 10), (11, 9)], + ) + + row = conn.execute( + table.select().order_by(table.c["percent%"]) + ).first() + eq_(row._mapping["percent%"], 5) + eq_(row._mapping["spaces % more spaces"], 12) + + eq_(row._mapping[table.c["percent%"]], 5) + eq_(row._mapping[table.c["spaces % more spaces"]], 12) + + conn.execute( + percent_table.update().values( + {percent_table.c["spaces % more spaces"]: 15} + ) + ) + + eq_( + list( + conn.execute( + percent_table.select().order_by( + percent_table.c["percent%"] + ) + ) + ), + [(5, 15), (7, 15), (9, 15), (11, 15)], + ) + + +class ServerSideCursorsTest( + fixtures.TestBase, testing.AssertsExecutionResults +): + + __requires__ = ("server_side_cursors",) + + __backend__ = True + + def _is_server_side(self, cursor): + # TODO: this is a huge issue as it prevents these tests from being + # usable by third party dialects. + if self.engine.dialect.driver == "psycopg2": + return bool(cursor.name) + elif self.engine.dialect.driver == "pymysql": + sscursor = __import__("pymysql.cursors").cursors.SSCursor + return isinstance(cursor, sscursor) + elif self.engine.dialect.driver in ("aiomysql", "asyncmy"): + return cursor.server_side + elif self.engine.dialect.driver == "mysqldb": + sscursor = __import__("MySQLdb.cursors").cursors.SSCursor + return isinstance(cursor, sscursor) + elif self.engine.dialect.driver == "mariadbconnector": + return not cursor.buffered + elif self.engine.dialect.driver in ("asyncpg", "aiosqlite"): + return cursor.server_side + elif self.engine.dialect.driver == "pg8000": + return getattr(cursor, "server_side", False) + else: + return False + + def _fixture(self, server_side_cursors): + if server_side_cursors: + with testing.expect_deprecated( + "The create_engine.server_side_cursors parameter is " + "deprecated and will be removed in a future release. " + "Please use the Connection.execution_options.stream_results " + "parameter." + ): + self.engine = engines.testing_engine( + options={"server_side_cursors": server_side_cursors} + ) + else: + self.engine = engines.testing_engine( + options={"server_side_cursors": server_side_cursors} + ) + return self.engine + + @testing.combinations( + ("global_string", True, "select 1", True), + ("global_text", True, text("select 1"), True), + ("global_expr", True, select(1), True), + ("global_off_explicit", False, text("select 1"), False), + ( + "stmt_option", + False, + select(1).execution_options(stream_results=True), + True, + ), + ( + "stmt_option_disabled", + True, + select(1).execution_options(stream_results=False), + False, + ), + ("for_update_expr", True, select(1).with_for_update(), True), + # TODO: need a real requirement for this, or dont use this test + ( + "for_update_string", + True, + "SELECT 1 FOR UPDATE", + True, + testing.skip_if("sqlite"), + ), + ("text_no_ss", False, text("select 42"), False), + ( + "text_ss_option", + False, + text("select 42").execution_options(stream_results=True), + True, + ), + id_="iaaa", + argnames="engine_ss_arg, statement, cursor_ss_status", + ) + def test_ss_cursor_status( + self, engine_ss_arg, statement, cursor_ss_status + ): + engine = self._fixture(engine_ss_arg) + with engine.begin() as conn: + if isinstance(statement, util.string_types): + result = conn.exec_driver_sql(statement) + else: + result = conn.execute(statement) + eq_(self._is_server_side(result.cursor), cursor_ss_status) + result.close() + + def test_conn_option(self): + engine = self._fixture(False) + + with engine.connect() as conn: + # should be enabled for this one + result = conn.execution_options( + stream_results=True + ).exec_driver_sql("select 1") + assert self._is_server_side(result.cursor) + + def test_stmt_enabled_conn_option_disabled(self): + engine = self._fixture(False) + + s = select(1).execution_options(stream_results=True) + + with engine.connect() as conn: + # not this one + result = conn.execution_options(stream_results=False).execute(s) + assert not self._is_server_side(result.cursor) + + def test_aliases_and_ss(self): + engine = self._fixture(False) + s1 = ( + select(sql.literal_column("1").label("x")) + .execution_options(stream_results=True) + .subquery() + ) + + # options don't propagate out when subquery is used as a FROM clause + with engine.begin() as conn: + result = conn.execute(s1.select()) + assert not self._is_server_side(result.cursor) + result.close() + + s2 = select(1).select_from(s1) + with engine.begin() as conn: + result = conn.execute(s2) + assert not self._is_server_side(result.cursor) + result.close() + + def test_roundtrip_fetchall(self, metadata): + md = self.metadata + + engine = self._fixture(True) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + with engine.begin() as connection: + test_table.create(connection, checkfirst=True) + connection.execute(test_table.insert(), dict(data="data1")) + connection.execute(test_table.insert(), dict(data="data2")) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2")], + ) + connection.execute( + test_table.update() + .where(test_table.c.id == 2) + .values(data=test_table.c.data + " updated") + ) + eq_( + connection.execute( + test_table.select().order_by(test_table.c.id) + ).fetchall(), + [(1, "data1"), (2, "data2 updated")], + ) + connection.execute(test_table.delete()) + eq_( + connection.scalar( + select(func.count("*")).select_from(test_table) + ), + 0, + ) + + def test_roundtrip_fetchmany(self, metadata): + md = self.metadata + + engine = self._fixture(True) + test_table = Table( + "test_table", + md, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + with engine.begin() as connection: + test_table.create(connection, checkfirst=True) + connection.execute( + test_table.insert(), + [dict(data="data%d" % i) for i in range(1, 20)], + ) + + result = connection.execute( + test_table.select().order_by(test_table.c.id) + ) + + eq_( + result.fetchmany(5), + [(i, "data%d" % i) for i in range(1, 6)], + ) + eq_( + result.fetchmany(10), + [(i, "data%d" % i) for i in range(6, 16)], + ) + eq_(result.fetchall(), [(i, "data%d" % i) for i in range(16, 20)]) diff --git a/lib/sqlalchemy/testing/suite/test_rowcount.py b/lib/sqlalchemy/testing/suite/test_rowcount.py new file mode 100644 index 0000000..82e831f --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_rowcount.py @@ -0,0 +1,165 @@ +from sqlalchemy import bindparam +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy import text +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures + + +class RowCountTest(fixtures.TablesTest): + """test rowcount functionality""" + + __requires__ = ("sane_rowcount",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "employees", + metadata, + Column( + "employee_id", + Integer, + autoincrement=False, + primary_key=True, + ), + Column("name", String(50)), + Column("department", String(1)), + ) + + @classmethod + def insert_data(cls, connection): + cls.data = data = [ + ("Angela", "A"), + ("Andrew", "A"), + ("Anand", "A"), + ("Bob", "B"), + ("Bobette", "B"), + ("Buffy", "B"), + ("Charlie", "C"), + ("Cynthia", "C"), + ("Chris", "C"), + ] + + employees_table = cls.tables.employees + connection.execute( + employees_table.insert(), + [ + {"employee_id": i, "name": n, "department": d} + for i, (n, d) in enumerate(data) + ], + ) + + def test_basic(self, connection): + employees_table = self.tables.employees + s = select( + employees_table.c.name, employees_table.c.department + ).order_by(employees_table.c.employee_id) + rows = connection.execute(s).fetchall() + + eq_(rows, self.data) + + def test_update_rowcount1(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 3 rows changed + department = employees_table.c.department + r = connection.execute( + employees_table.update().where(department == "C"), + {"department": "Z"}, + ) + assert r.rowcount == 3 + + def test_update_rowcount2(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 0 rows changed + department = employees_table.c.department + + r = connection.execute( + employees_table.update().where(department == "C"), + {"department": "C"}, + ) + eq_(r.rowcount, 3) + + @testing.requires.sane_rowcount_w_returning + def test_update_rowcount_return_defaults(self, connection): + employees_table = self.tables.employees + + department = employees_table.c.department + stmt = ( + employees_table.update() + .where(department == "C") + .values(name=employees_table.c.department + "Z") + .return_defaults() + ) + + r = connection.execute(stmt) + eq_(r.rowcount, 3) + + def test_raw_sql_rowcount(self, connection): + # test issue #3622, make sure eager rowcount is called for text + result = connection.exec_driver_sql( + "update employees set department='Z' where department='C'" + ) + eq_(result.rowcount, 3) + + def test_text_rowcount(self, connection): + # test issue #3622, make sure eager rowcount is called for text + result = connection.execute( + text("update employees set department='Z' " "where department='C'") + ) + eq_(result.rowcount, 3) + + def test_delete_rowcount(self, connection): + employees_table = self.tables.employees + + # WHERE matches 3, 3 rows deleted + department = employees_table.c.department + r = connection.execute( + employees_table.delete().where(department == "C") + ) + eq_(r.rowcount, 3) + + @testing.requires.sane_multi_rowcount + def test_multi_update_rowcount(self, connection): + employees_table = self.tables.employees + stmt = ( + employees_table.update() + .where(employees_table.c.name == bindparam("emp_name")) + .values(department="C") + ) + + r = connection.execute( + stmt, + [ + {"emp_name": "Bob"}, + {"emp_name": "Cynthia"}, + {"emp_name": "nonexistent"}, + ], + ) + + eq_(r.rowcount, 2) + + @testing.requires.sane_multi_rowcount + def test_multi_delete_rowcount(self, connection): + employees_table = self.tables.employees + + stmt = employees_table.delete().where( + employees_table.c.name == bindparam("emp_name") + ) + + r = connection.execute( + stmt, + [ + {"emp_name": "Bob"}, + {"emp_name": "Cynthia"}, + {"emp_name": "nonexistent"}, + ], + ) + + eq_(r.rowcount, 2) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py new file mode 100644 index 0000000..cb78fff --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -0,0 +1,1783 @@ +import itertools + +from .. import AssertsCompiledSQL +from .. import AssertsExecutionResults +from .. import config +from .. import fixtures +from ..assertions import assert_raises +from ..assertions import eq_ +from ..assertions import in_ +from ..assertsql import CursorSQL +from ..schema import Column +from ..schema import Table +from ... import bindparam +from ... import case +from ... import column +from ... import Computed +from ... import exists +from ... import false +from ... import ForeignKey +from ... import func +from ... import Identity +from ... import Integer +from ... import literal +from ... import literal_column +from ... import null +from ... import select +from ... import String +from ... import table +from ... import testing +from ... import text +from ... import true +from ... import tuple_ +from ... import TupleType +from ... import union +from ... import util +from ... import values +from ...exc import DatabaseError +from ...exc import ProgrammingError +from ...util import collections_abc + + +class CollateTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(100)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "collate data1"}, + {"id": 2, "data": "collate data2"}, + ], + ) + + def _assert_result(self, select, result): + with config.db.connect() as conn: + eq_(conn.execute(select).fetchall(), result) + + @testing.requires.order_by_collation + def test_collate_order_by(self): + collation = testing.requires.get_order_by_collation(testing.config) + + self._assert_result( + select(self.tables.some_table).order_by( + self.tables.some_table.c.data.collate(collation).asc() + ), + [(1, "collate data1"), (2, "collate data2")], + ) + + +class OrderByLabelTest(fixtures.TablesTest): + """Test the dialect sends appropriate ORDER BY expressions when + labels are used. + + This essentially exercises the "supports_simple_order_by_label" + setting. + + """ + + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("q", String(50)), + Column("p", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"}, + {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"}, + {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"}, + ], + ) + + def _assert_result(self, select, result): + with config.db.connect() as conn: + eq_(conn.execute(select).fetchall(), result) + + def test_plain(self): + table = self.tables.some_table + lx = table.c.x.label("lx") + self._assert_result(select(lx).order_by(lx), [(1,), (2,), (3,)]) + + def test_composed_int(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + self._assert_result(select(lx).order_by(lx), [(3,), (5,), (7,)]) + + def test_composed_multiple(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + ly = (func.lower(table.c.q) + table.c.p).label("ly") + self._assert_result( + select(lx, ly).order_by(lx, ly.desc()), + [(3, util.u("q1p3")), (5, util.u("q2p2")), (7, util.u("q3p1"))], + ) + + def test_plain_desc(self): + table = self.tables.some_table + lx = table.c.x.label("lx") + self._assert_result(select(lx).order_by(lx.desc()), [(3,), (2,), (1,)]) + + def test_composed_int_desc(self): + table = self.tables.some_table + lx = (table.c.x + table.c.y).label("lx") + self._assert_result(select(lx).order_by(lx.desc()), [(7,), (5,), (3,)]) + + @testing.requires.group_by_complex_expression + def test_group_by_composed(self): + table = self.tables.some_table + expr = (table.c.x + table.c.y).label("lx") + stmt = ( + select(func.count(table.c.id), expr).group_by(expr).order_by(expr) + ) + self._assert_result(stmt, [(1, 3), (1, 5), (1, 7)]) + + +class ValuesExpressionTest(fixtures.TestBase): + __requires__ = ("table_value_constructor",) + + __backend__ = True + + def test_tuples(self, connection): + value_expr = values( + column("id", Integer), column("name", String), name="my_values" + ).data([(1, "name1"), (2, "name2"), (3, "name3")]) + + eq_( + connection.execute(select(value_expr)).all(), + [(1, "name1"), (2, "name2"), (3, "name3")], + ) + + +class FetchLimitOffsetTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + {"id": 5, "x": 4, "y": 6}, + ], + ) + + def _assert_result( + self, connection, select, result, params=(), set_=False + ): + if set_: + query_res = connection.execute(select, params).fetchall() + eq_(len(query_res), len(result)) + eq_(set(query_res), set(result)) + + else: + eq_(connection.execute(select, params).fetchall(), result) + + def _assert_result_str(self, select, result, params=()): + conn = config.db.connect(close_with_result=True) + eq_(conn.exec_driver_sql(select, params).fetchall(), result) + + def test_simple_limit(self, connection): + table = self.tables.some_table + stmt = select(table).order_by(table.c.id) + self._assert_result( + connection, + stmt.limit(2), + [(1, 1, 2), (2, 2, 3)], + ) + self._assert_result( + connection, + stmt.limit(3), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + def test_limit_render_multiple_times(self, connection): + table = self.tables.some_table + stmt = select(table.c.id).limit(1).scalar_subquery() + + u = union(select(stmt), select(stmt)).subquery().select() + + self._assert_result( + connection, + u, + [ + (1,), + ], + ) + + @testing.requires.fetch_first + def test_simple_fetch(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(2), + [(1, 1, 2), (2, 2, 3)], + ) + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(3), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.offset + def test_simple_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(2), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(3), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.combinations( + ([(2, 0), (2, 1), (3, 2)]), + ([(2, 1), (2, 0), (3, 2)]), + ([(3, 1), (2, 1), (3, 1)]), + argnames="cases", + ) + @testing.requires.offset + def test_simple_limit_offset(self, connection, cases): + table = self.tables.some_table + connection = connection.execution_options(compiled_cache={}) + + assert_data = [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)] + + for limit, offset in cases: + expected = assert_data[offset : offset + limit] + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(limit).offset(offset), + expected, + ) + + @testing.requires.fetch_first + def test_simple_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(2).offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(3).offset(2), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.fetch_no_order_by + def test_fetch_offset_no_order(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).fetch(10), + [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.offset + def test_simple_offset_zero(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(0), + [(1, 1, 2), (2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(1), + [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.offset + def test_limit_offset_nobinds(self): + """test that 'literal binds' mode works - no bound params.""" + + table = self.tables.some_table + stmt = select(table).order_by(table.c.id).limit(2).offset(1) + sql = stmt.compile( + dialect=config.db.dialect, compile_kwargs={"literal_binds": True} + ) + sql = str(sql) + + self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)]) + + @testing.requires.fetch_first + def test_fetch_offset_nobinds(self): + """test that 'literal binds' mode works - no bound params.""" + + table = self.tables.some_table + stmt = select(table).order_by(table.c.id).fetch(2).offset(1) + sql = stmt.compile( + dialect=config.db.dialect, compile_kwargs={"literal_binds": True} + ) + sql = str(sql) + + self._assert_result_str(sql, [(2, 2, 3), (3, 3, 4)]) + + @testing.requires.bound_limit_offset + def test_bound_limit(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(bindparam("l")), + [(1, 1, 2), (2, 2, 3)], + params={"l": 2}, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).limit(bindparam("l")), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + params={"l": 3}, + ) + + @testing.requires.bound_limit_offset + def test_bound_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"o": 2}, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.id).offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"o": 1}, + ) + + @testing.requires.bound_limit_offset + def test_bound_limit_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(bindparam("l")) + .offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4)], + params={"l": 2, "o": 1}, + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(bindparam("l")) + .offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"l": 3, "o": 2}, + ) + + @testing.requires.fetch_first + def test_bound_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(bindparam("f")) + .offset(bindparam("o")), + [(2, 2, 3), (3, 3, 4)], + params={"f": 2, "o": 1}, + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(bindparam("f")) + .offset(bindparam("o")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + params={"f": 3, "o": 2}, + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .offset(literal_column("1") + literal_column("2")), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("2")), + [(1, 1, 2), (2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + @testing.requires.fetch_first + @testing.requires.fetch_expression + def test_expr_fetch_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(literal_column("1") + literal_column("1")) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + @testing.requires.sql_expression_limit_offset + def test_simple_limit_expr_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(2) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(3) + .offset(literal_column("1") + literal_column("1")), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.sql_expression_limit_offset + def test_expr_limit_simple_offset(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(2), + [(3, 3, 4), (4, 4, 5)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .limit(literal_column("1") + literal_column("1")) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.fetch_ties + def test_simple_fetch_ties(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.x.desc()).fetch(1, with_ties=True), + [(4, 4, 5), (5, 4, 6)], + set_=True, + ) + + self._assert_result( + connection, + select(table).order_by(table.c.x.desc()).fetch(3, with_ties=True), + [(3, 3, 4), (4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_offset_with_options + def test_fetch_offset_ties(self, connection): + table = self.tables.some_table + fa = connection.execute( + select(table) + .order_by(table.c.x) + .fetch(2, with_ties=True) + .offset(2) + ).fetchall() + eq_(fa[0], (3, 3, 4)) + eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) + + @testing.requires.fetch_ties + @testing.requires.fetch_offset_with_options + def test_fetch_offset_ties_exact_number(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.x) + .fetch(2, with_ties=True) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + self._assert_result( + connection, + select(table) + .order_by(table.c.x) + .fetch(3, with_ties=True) + .offset(3), + [(4, 4, 5), (5, 4, 6)], + ) + + @testing.requires.fetch_percent + def test_simple_fetch_percent(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table).order_by(table.c.id).fetch(20, percent=True), + [(1, 1, 2)], + ) + + @testing.requires.fetch_percent + @testing.requires.fetch_offset_with_options + def test_fetch_offset_percent(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.id) + .fetch(40, percent=True) + .offset(1), + [(2, 2, 3), (3, 3, 4)], + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_percent + def test_simple_fetch_percent_ties(self, connection): + table = self.tables.some_table + self._assert_result( + connection, + select(table) + .order_by(table.c.x.desc()) + .fetch(20, percent=True, with_ties=True), + [(4, 4, 5), (5, 4, 6)], + set_=True, + ) + + @testing.requires.fetch_ties + @testing.requires.fetch_percent + @testing.requires.fetch_offset_with_options + def test_fetch_offset_percent_ties(self, connection): + table = self.tables.some_table + fa = connection.execute( + select(table) + .order_by(table.c.x) + .fetch(40, percent=True, with_ties=True) + .offset(2) + ).fetchall() + eq_(fa[0], (3, 3, 4)) + eq_(set(fa), set([(3, 3, 4), (4, 4, 5), (5, 4, 6)])) + + +class JoinTest(fixtures.TablesTest): + __backend__ = True + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + @classmethod + def define_tables(cls, metadata): + Table("a", metadata, Column("id", Integer, primary_key=True)) + Table( + "b", + metadata, + Column("id", Integer, primary_key=True), + Column("a_id", ForeignKey("a.id"), nullable=False), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.a.insert(), + [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}], + ) + + connection.execute( + cls.tables.b.insert(), + [ + {"id": 1, "a_id": 1}, + {"id": 2, "a_id": 1}, + {"id": 4, "a_id": 2}, + {"id": 5, "a_id": 3}, + ], + ) + + def test_inner_join_fk(self): + a, b = self.tables("a", "b") + + stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id) + + self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)]) + + def test_inner_join_true(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.join(b, true())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result( + stmt, + [ + (a, b, c) + for (a,), (b, c) in itertools.product( + [(1,), (2,), (3,), (4,), (5,)], + [(1, 1), (2, 1), (4, 2), (5, 3)], + ) + ], + ) + + def test_inner_join_false(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.join(b, false())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result(stmt, []) + + def test_outer_join_false(self): + a, b = self.tables("a", "b") + + stmt = ( + select(a, b) + .select_from(a.outerjoin(b, false())) + .order_by(a.c.id, b.c.id) + ) + + self._assert_result( + stmt, + [ + (1, None, None), + (2, None, None), + (3, None, None), + (4, None, None), + (5, None, None), + ], + ) + + def test_outer_join_fk(self): + a, b = self.tables("a", "b") + + stmt = select(a, b).select_from(a.join(b)).order_by(a.c.id, b.c.id) + + self._assert_result(stmt, [(1, 1, 1), (1, 2, 1), (2, 4, 2), (3, 5, 3)]) + + +class CompoundSelectTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ], + ) + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + def test_plain_union(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_select_from_plain_union(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2) + s2 = select(table).where(table.c.id == 3) + + u1 = union(s1, s2).alias().select() + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.order_by_col_from_union + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_wo_limit_offset + def test_order_by_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).order_by(table.c.id) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_distinct_selectable_in_unions(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).distinct() + s2 = select(table).where(table.c.id == 3).distinct() + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + @testing.requires.parens_in_union_contained_select_w_limit_offset + def test_limit_offset_in_unions_from_alias(self): + table = self.tables.some_table + s1 = select(table).where(table.c.id == 2).limit(1).order_by(table.c.id) + s2 = select(table).where(table.c.id == 3).limit(1).order_by(table.c.id) + + # this necessarily has double parens + u1 = union(s1, s2).alias() + self._assert_result( + u1.select().limit(2).order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)] + ) + + def test_limit_offset_aliased_selectable_in_unions(self): + table = self.tables.some_table + s1 = ( + select(table) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select(table) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + + u1 = union(s1, s2).limit(2) + self._assert_result( + u1.order_by(u1.selected_columns.id), [(2, 2, 3), (3, 3, 4)] + ) + + +class PostCompileParamsTest( + AssertsExecutionResults, AssertsCompiledSQL, fixtures.TablesTest +): + __backend__ = True + + __requires__ = ("standard_cursor_sql",) + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "z": "z1"}, + {"id": 2, "x": 2, "y": 3, "z": "z2"}, + {"id": 3, "x": 3, "y": 4, "z": "z3"}, + {"id": 4, "x": 4, "y": 5, "z": "z4"}, + ], + ) + + def test_compile(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", literal_execute=True) + ) + + self.assert_compile( + stmt, + "SELECT some_table.id FROM some_table " + "WHERE some_table.x = __[POSTCOMPILE_q]", + {}, + ) + + def test_compile_literal_binds(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", 10, literal_execute=True) + ) + + self.assert_compile( + stmt, + "SELECT some_table.id FROM some_table WHERE some_table.x = 10", + {}, + literal_binds=True, + ) + + def test_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x == bindparam("q", literal_execute=True) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=10)) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE some_table.x = 10", + () if config.db.dialect.positional else {}, + ) + ) + + def test_execute_expanding_plus_literal_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + table.c.x.in_(bindparam("q", expanding=True, literal_execute=True)) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[5, 6, 7])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE some_table.x IN (5, 6, 7)", + () if config.db.dialect.positional else {}, + ) + ) + + @testing.requires.tuple_in + def test_execute_tuple_expanding_plus_literal_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + tuple_(table.c.x, table.c.y).in_( + bindparam("q", expanding=True, literal_execute=True) + ) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[(5, 10), (12, 18)])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE (some_table.x, some_table.y) " + "IN (%s(5, 10), (12, 18))" + % ("VALUES " if config.db.dialect.tuple_in_values else ""), + () if config.db.dialect.positional else {}, + ) + ) + + @testing.requires.tuple_in + def test_execute_tuple_expanding_plus_literal_heterogeneous_execute(self): + table = self.tables.some_table + + stmt = select(table.c.id).where( + tuple_(table.c.x, table.c.z).in_( + bindparam("q", expanding=True, literal_execute=True) + ) + ) + + with self.sql_execution_asserter() as asserter: + with config.db.connect() as conn: + conn.execute(stmt, dict(q=[(5, "z1"), (12, "z3")])) + + asserter.assert_( + CursorSQL( + "SELECT some_table.id \nFROM some_table " + "\nWHERE (some_table.x, some_table.z) " + "IN (%s(5, 'z1'), (12, 'z3'))" + % ("VALUES " if config.db.dialect.tuple_in_values else ""), + () if config.db.dialect.positional else {}, + ) + ) + + +class ExpandingBoundInTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("z", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "x": 1, "y": 2, "z": "z1"}, + {"id": 2, "x": 2, "y": 3, "z": "z2"}, + {"id": 3, "x": 3, "y": 4, "z": "z3"}, + {"id": 4, "x": 4, "y": 5, "z": "z4"}, + ], + ) + + def _assert_result(self, select, result, params=()): + with config.db.connect() as conn: + eq_(conn.execute(select, params).fetchall(), result) + + def test_multiple_empty_sets_bindparam(self): + # test that any anonymous aliasing used by the dialect + # is fine with duplicates + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .where(table.c.y.in_(bindparam("p"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": [], "p": []}) + + def test_multiple_empty_sets_direct(self): + # test that any anonymous aliasing used by the dialect + # is fine with duplicates + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([])) + .where(table.c.y.in_([])) + .order_by(table.c.id) + ) + self._assert_result(stmt, []) + + @testing.requires.tuple_in_w_empty + def test_empty_heterogeneous_tuples_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + @testing.requires.tuple_in_w_empty + def test_empty_heterogeneous_tuples_direct(self): + table = self.tables.some_table + + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)]) + go([], []) + + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_direct(self): + table = self.tables.some_table + + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)]) + go([], []) + + def test_bound_in_scalar_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]}) + + def test_bound_in_scalar_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([2, 3, 4])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + + def test_nonempty_in_plus_empty_notin(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([2, 3])) + .where(table.c.id.not_in([])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,)]) + + def test_empty_in_plus_notempty_notin(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_([])) + .where(table.c.id.not_in([2, 3])) + .order_by(table.c.id) + ) + self._assert_result(stmt, []) + + def test_typed_str_in(self): + """test related to #7292. + + as a type is given to the bound param, there is no ambiguity + to the type of element. + + """ + + stmt = text( + "select id FROM some_table WHERE z IN :q ORDER BY id" + ).bindparams(bindparam("q", type_=String, expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": ["z2", "z3", "z4"]}, + ) + + def test_untyped_str_in(self): + """test related to #7292. + + for untyped expression, we look at the types of elements. + Test for Sequence to detect tuple in. but not strings or bytes! + as always.... + + """ + + stmt = text( + "select id FROM some_table WHERE z IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": ["z2", "z3", "z4"]}, + ) + + @testing.requires.tuple_in + def test_bound_in_two_tuple_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result( + stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]} + ) + + @testing.requires.tuple_in + def test_bound_in_two_tuple_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where( + tuple_(table.c.x, table.c.z).in_( + [(2, "z2"), (3, "z3"), (4, "z4")] + ) + ) + .order_by(table.c.id) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_text_bindparam(self): + # note this becomes ARRAY if we dont use expanding + # explicitly right now + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self): + class LikeATuple(collections_abc.Sequence): + def __init__(self, *data): + self._data = data + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams( + bindparam( + "q", type_=TupleType(Integer(), String()), expanding=True + ) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={ + "q": [ + LikeATuple(2, "z2"), + LikeATuple(3, "z3"), + LikeATuple(4, "z4"), + ] + }, + ) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self): + # note this becomes ARRAY if we dont use expanding + # explicitly right now + + class LikeATuple(collections_abc.Sequence): + def __init__(self, *data): + self._data = data + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + stmt = text( + "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" + ).bindparams(bindparam("q", expanding=True)) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={ + "q": [ + LikeATuple(2, "z2"), + LikeATuple(3, "z3"), + LikeATuple(4, "z4"), + ] + }, + ) + + def test_empty_set_against_integer_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + def test_empty_set_against_integer_direct(self): + table = self.tables.some_table + stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + + def test_empty_set_against_integer_negation_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.x.not_in(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + + def test_empty_set_against_integer_negation_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + + def test_empty_set_against_string_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.z.in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [], params={"q": []}) + + def test_empty_set_against_string_direct(self): + table = self.tables.some_table + stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + + def test_empty_set_against_string_negation_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(table.c.z.not_in(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) + + def test_empty_set_against_string_negation_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + + def test_null_in_empty_set_is_false_bindparam(self, connection): + stmt = select( + case( + ( + null().in_(bindparam("foo", value=())), + true(), + ), + else_=false(), + ) + ) + in_(connection.execute(stmt).fetchone()[0], (False, 0)) + + def test_null_in_empty_set_is_false_direct(self, connection): + stmt = select( + case( + ( + null().in_([]), + true(), + ), + else_=false(), + ) + ) + in_(connection.execute(stmt).fetchone()[0], (False, 0)) + + +class LikeFunctionsTest(fixtures.TablesTest): + __backend__ = True + + run_inserts = "once" + run_deletes = None + + @classmethod + def define_tables(cls, metadata): + Table( + "some_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.some_table.insert(), + [ + {"id": 1, "data": "abcdefg"}, + {"id": 2, "data": "ab/cdefg"}, + {"id": 3, "data": "ab%cdefg"}, + {"id": 4, "data": "ab_cdefg"}, + {"id": 5, "data": "abcde/fg"}, + {"id": 6, "data": "abcde%fg"}, + {"id": 7, "data": "ab#cdefg"}, + {"id": 8, "data": "ab9cdefg"}, + {"id": 9, "data": "abcde#fg"}, + {"id": 10, "data": "abcd9fg"}, + {"id": 11, "data": None}, + ], + ) + + def _test(self, expr, expected): + some_table = self.tables.some_table + + with config.db.connect() as conn: + rows = { + value + for value, in conn.execute(select(some_table.c.id).where(expr)) + } + + eq_(rows, expected) + + def test_startswith_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c"), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + + def test_startswith_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c", autoescape=True), {3}) + + def test_startswith_sqlexpr(self): + col = self.tables.some_table.c.data + self._test( + col.startswith(literal_column("'ab%c'")), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + ) + + def test_startswith_escape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab##c", escape="#"), {7}) + + def test_startswith_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.startswith("ab%c", autoescape=True, escape="#"), {3}) + self._test(col.startswith("ab#c", autoescape=True, escape="#"), {7}) + + def test_endswith_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg"), {1, 2, 3, 4, 5, 6, 7, 8, 9}) + + def test_endswith_sqlexpr(self): + col = self.tables.some_table.c.data + self._test( + col.endswith(literal_column("'e%fg'")), {1, 2, 3, 4, 5, 6, 7, 8, 9} + ) + + def test_endswith_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg", autoescape=True), {6}) + + def test_endswith_escape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e##fg", escape="#"), {9}) + + def test_endswith_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.endswith("e%fg", autoescape=True, escape="#"), {6}) + self._test(col.endswith("e#fg", autoescape=True, escape="#"), {9}) + + def test_contains_unescaped(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cde"), {1, 2, 3, 4, 5, 6, 7, 8, 9}) + + def test_contains_autoescape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cde", autoescape=True), {3}) + + def test_contains_escape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b##cde", escape="#"), {7}) + + def test_contains_autoescape_escape(self): + col = self.tables.some_table.c.data + self._test(col.contains("b%cd", autoescape=True, escape="#"), {3}) + self._test(col.contains("b#cd", autoescape=True, escape="#"), {7}) + + @testing.requires.regexp_match + def test_not_regexp_match(self): + col = self.tables.some_table.c.data + self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10}) + + @testing.requires.regexp_replace + def test_regexp_replace(self): + col = self.tables.some_table.c.data + self._test( + col.regexp_replace("a.cde", "FOO").contains("FOO"), {1, 5, 6, 9} + ) + + @testing.requires.regexp_match + @testing.combinations( + ("a.cde", {1, 5, 6, 9}), + ("abc", {1, 5, 6, 9, 10}), + ("^abc", {1, 5, 6, 9, 10}), + ("9cde", {8}), + ("^a", set(range(1, 11))), + ("(b|c)", set(range(1, 11))), + ("^(b|c)", set()), + ) + def test_regexp_match(self, text, expected): + col = self.tables.some_table.c.data + self._test(col.regexp_match(text), expected) + + +class ComputedColumnTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("computed_columns",) + + @classmethod + def define_tables(cls, metadata): + Table( + "square", + metadata, + Column("id", Integer, primary_key=True), + Column("side", Integer), + Column("area", Integer, Computed("side * side")), + Column("perimeter", Integer, Computed("4 * side")), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.square.insert(), + [{"id": 1, "side": 10}, {"id": 10, "side": 42}], + ) + + def test_select_all(self): + with config.db.connect() as conn: + res = conn.execute( + select(text("*")) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(1, 10, 100, 40), (10, 42, 1764, 168)]) + + def test_select_columns(self): + with config.db.connect() as conn: + res = conn.execute( + select( + self.tables.square.c.area, self.tables.square.c.perimeter + ) + .select_from(self.tables.square) + .order_by(self.tables.square.c.id) + ).fetchall() + eq_(res, [(100, 40), (1764, 168)]) + + +class IdentityColumnTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("identity_columns",) + run_inserts = "once" + run_deletes = "once" + + @classmethod + def define_tables(cls, metadata): + Table( + "tbl_a", + metadata, + Column( + "id", + Integer, + Identity( + always=True, start=42, nominvalue=True, nomaxvalue=True + ), + primary_key=True, + ), + Column("desc", String(100)), + ) + Table( + "tbl_b", + metadata, + Column( + "id", + Integer, + Identity(increment=-5, start=0, minvalue=-1000, maxvalue=0), + primary_key=True, + ), + Column("desc", String(100)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.tbl_a.insert(), + [{"desc": "a"}, {"desc": "b"}], + ) + connection.execute( + cls.tables.tbl_b.insert(), + [{"desc": "a"}, {"desc": "b"}], + ) + connection.execute( + cls.tables.tbl_b.insert(), + [{"id": 42, "desc": "c"}], + ) + + def test_select_all(self, connection): + res = connection.execute( + select(text("*")) + .select_from(self.tables.tbl_a) + .order_by(self.tables.tbl_a.c.id) + ).fetchall() + eq_(res, [(42, "a"), (43, "b")]) + + res = connection.execute( + select(text("*")) + .select_from(self.tables.tbl_b) + .order_by(self.tables.tbl_b.c.id) + ).fetchall() + eq_(res, [(-5, "b"), (0, "a"), (42, "c")]) + + def test_select_columns(self, connection): + + res = connection.execute( + select(self.tables.tbl_a.c.id).order_by(self.tables.tbl_a.c.id) + ).fetchall() + eq_(res, [(42,), (43,)]) + + @testing.requires.identity_columns_standard + def test_insert_always_error(self, connection): + def fn(): + connection.execute( + self.tables.tbl_a.insert(), + [{"id": 200, "desc": "a"}], + ) + + assert_raises((DatabaseError, ProgrammingError), fn) + + +class IdentityAutoincrementTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("autoincrement_without_sequence",) + + @classmethod + def define_tables(cls, metadata): + Table( + "tbl", + metadata, + Column( + "id", + Integer, + Identity(), + primary_key=True, + autoincrement=True, + ), + Column("desc", String(100)), + ) + + def test_autoincrement_with_identity(self, connection): + res = connection.execute(self.tables.tbl.insert(), {"desc": "row"}) + res = connection.execute(self.tables.tbl.select()).first() + eq_(res, (1, "row")) + + +class ExistsTest(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "stuff", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.stuff.insert(), + [ + {"id": 1, "data": "some data"}, + {"id": 2, "data": "some data"}, + {"id": 3, "data": "some data"}, + {"id": 4, "data": "some other data"}, + ], + ) + + def test_select_exists(self, connection): + stuff = self.tables.stuff + eq_( + connection.execute( + select(literal(1)).where( + exists().where(stuff.c.data == "some data") + ) + ).fetchall(), + [(1,)], + ) + + def test_select_exists_false(self, connection): + stuff = self.tables.stuff + eq_( + connection.execute( + select(literal(1)).where( + exists().where(stuff.c.data == "no data") + ) + ).fetchall(), + [], + ) + + +class DistinctOnTest(AssertsCompiledSQL, fixtures.TablesTest): + __backend__ = True + + @testing.fails_if(testing.requires.supports_distinct_on) + def test_distinct_on(self): + stm = select("*").distinct(column("q")).select_from(table("foo")) + with testing.expect_deprecated( + "DISTINCT ON is currently supported only by the PostgreSQL " + ): + self.assert_compile(stm, "SELECT DISTINCT * FROM foo") + + +class IsOrIsNotDistinctFromTest(fixtures.TablesTest): + __backend__ = True + __requires__ = ("supports_is_distinct_from",) + + @classmethod + def define_tables(cls, metadata): + Table( + "is_distinct_test", + metadata, + Column("id", Integer, primary_key=True), + Column("col_a", Integer, nullable=True), + Column("col_b", Integer, nullable=True), + ) + + @testing.combinations( + ("both_int_different", 0, 1, 1), + ("both_int_same", 1, 1, 0), + ("one_null_first", None, 1, 1), + ("one_null_second", 0, None, 1), + ("both_null", None, None, 0), + id_="iaaa", + argnames="col_a_value, col_b_value, expected_row_count_for_is", + ) + def test_is_or_is_not_distinct_from( + self, col_a_value, col_b_value, expected_row_count_for_is, connection + ): + tbl = self.tables.is_distinct_test + + connection.execute( + tbl.insert(), + [{"id": 1, "col_a": col_a_value, "col_b": col_b_value}], + ) + + result = connection.execute( + tbl.select().where(tbl.c.col_a.is_distinct_from(tbl.c.col_b)) + ).fetchall() + eq_( + len(result), + expected_row_count_for_is, + ) + + expected_row_count_for_is_not = ( + 1 if expected_row_count_for_is == 0 else 0 + ) + result = connection.execute( + tbl.select().where(tbl.c.col_a.is_not_distinct_from(tbl.c.col_b)) + ).fetchall() + eq_( + len(result), + expected_row_count_for_is_not, + ) diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py new file mode 100644 index 0000000..d6747d2 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -0,0 +1,282 @@ +from .. import config +from .. import fixtures +from ..assertions import eq_ +from ..assertions import is_true +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import inspect +from ... import Integer +from ... import MetaData +from ... import Sequence +from ... import String +from ... import testing + + +class SequenceTest(fixtures.TablesTest): + __requires__ = ("sequences",) + __backend__ = True + + run_create_tables = "each" + + @classmethod + def define_tables(cls, metadata): + Table( + "seq_pk", + metadata, + Column( + "id", + Integer, + Sequence("tab_id_seq"), + primary_key=True, + ), + Column("data", String(50)), + ) + + Table( + "seq_opt_pk", + metadata, + Column( + "id", + Integer, + Sequence("tab_id_seq", data_type=Integer, optional=True), + primary_key=True, + ), + Column("data", String(50)), + ) + + Table( + "seq_no_returning", + metadata, + Column( + "id", + Integer, + Sequence("noret_id_seq"), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + ) + + if testing.requires.schemas.enabled: + Table( + "seq_no_returning_sch", + metadata, + Column( + "id", + Integer, + Sequence("noret_sch_id_seq", schema=config.test_schema), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=False, + schema=config.test_schema, + ) + + def test_insert_roundtrip(self, connection): + connection.execute(self.tables.seq_pk.insert(), dict(data="some data")) + self._assert_round_trip(self.tables.seq_pk, connection) + + def test_insert_lastrowid(self, connection): + r = connection.execute( + self.tables.seq_pk.insert(), dict(data="some data") + ) + eq_( + r.inserted_primary_key, (testing.db.dialect.default_sequence_base,) + ) + + def test_nextval_direct(self, connection): + r = connection.execute(self.tables.seq_pk.c.id.default) + eq_(r, testing.db.dialect.default_sequence_base) + + @requirements.sequences_optional + def test_optional_seq(self, connection): + r = connection.execute( + self.tables.seq_opt_pk.insert(), dict(data="some data") + ) + eq_(r.inserted_primary_key, (1,)) + + def _assert_round_trip(self, table, conn): + row = conn.execute(table.select()).first() + eq_(row, (testing.db.dialect.default_sequence_base, "some data")) + + def test_insert_roundtrip_no_implicit_returning(self, connection): + connection.execute( + self.tables.seq_no_returning.insert(), dict(data="some data") + ) + self._assert_round_trip(self.tables.seq_no_returning, connection) + + @testing.combinations((True,), (False,), argnames="implicit_returning") + @testing.requires.schemas + def test_insert_roundtrip_translate(self, connection, implicit_returning): + + seq_no_returning = Table( + "seq_no_returning_sch", + MetaData(), + Column( + "id", + Integer, + Sequence("noret_sch_id_seq", schema="alt_schema"), + primary_key=True, + ), + Column("data", String(50)), + implicit_returning=implicit_returning, + schema="alt_schema", + ) + + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + connection.execute(seq_no_returning.insert(), dict(data="some data")) + self._assert_round_trip(seq_no_returning, connection) + + @testing.requires.schemas + def test_nextval_direct_schema_translate(self, connection): + seq = Sequence("noret_sch_id_seq", schema="alt_schema") + connection = connection.execution_options( + schema_translate_map={"alt_schema": config.test_schema} + ) + + r = connection.execute(seq) + eq_(r, testing.db.dialect.default_sequence_base) + + +class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase): + __requires__ = ("sequences",) + __backend__ = True + + def test_literal_binds_inline_compile(self, connection): + table = Table( + "x", + MetaData(), + Column("y", Integer, Sequence("y_seq")), + Column("q", Integer), + ) + + stmt = table.insert().values(q=5) + + seq_nextval = connection.dialect.statement_compiler( + statement=None, dialect=connection.dialect + ).visit_sequence(Sequence("y_seq")) + self.assert_compile( + stmt, + "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval,), + literal_binds=True, + dialect=connection.dialect, + ) + + +class HasSequenceTest(fixtures.TablesTest): + run_deletes = None + + __requires__ = ("sequences",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Sequence("user_id_seq", metadata=metadata) + Sequence( + "other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True + ) + if testing.requires.schemas.enabled: + Sequence( + "user_id_seq", schema=config.test_schema, metadata=metadata + ) + Sequence( + "schema_seq", schema=config.test_schema, metadata=metadata + ) + Table( + "user_id_table", + metadata, + Column("id", Integer, primary_key=True), + ) + + def test_has_sequence(self, connection): + eq_( + inspect(connection).has_sequence("user_id_seq"), + True, + ) + + def test_has_sequence_other_object(self, connection): + eq_( + inspect(connection).has_sequence("user_id_table"), + False, + ) + + @testing.requires.schemas + def test_has_sequence_schema(self, connection): + eq_( + inspect(connection).has_sequence( + "user_id_seq", schema=config.test_schema + ), + True, + ) + + def test_has_sequence_neg(self, connection): + eq_( + inspect(connection).has_sequence("some_sequence"), + False, + ) + + @testing.requires.schemas + def test_has_sequence_schemas_neg(self, connection): + eq_( + inspect(connection).has_sequence( + "some_sequence", schema=config.test_schema + ), + False, + ) + + @testing.requires.schemas + def test_has_sequence_default_not_in_remote(self, connection): + eq_( + inspect(connection).has_sequence( + "other_sequence", schema=config.test_schema + ), + False, + ) + + @testing.requires.schemas + def test_has_sequence_remote_not_in_default(self, connection): + eq_( + inspect(connection).has_sequence("schema_seq"), + False, + ) + + def test_get_sequence_names(self, connection): + exp = {"other_seq", "user_id_seq"} + + res = set(inspect(connection).get_sequence_names()) + is_true(res.intersection(exp) == exp) + is_true("schema_seq" not in res) + + @testing.requires.schemas + def test_get_sequence_names_no_sequence_schema(self, connection): + eq_( + inspect(connection).get_sequence_names( + schema=config.test_schema_2 + ), + [], + ) + + @testing.requires.schemas + def test_get_sequence_names_sequences_schema(self, connection): + eq_( + sorted( + inspect(connection).get_sequence_names( + schema=config.test_schema + ) + ), + ["schema_seq", "user_id_seq"], + ) + + +class HasSequenceTestEmpty(fixtures.TestBase): + __requires__ = ("sequences",) + __backend__ = True + + def test_get_sequence_names_no_sequence(self, connection): + eq_( + inspect(connection).get_sequence_names(), + [], + ) diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py new file mode 100644 index 0000000..b96350e --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -0,0 +1,1508 @@ +# coding: utf-8 + +import datetime +import decimal +import json +import re + +from .. import config +from .. import engines +from .. import fixtures +from .. import mock +from ..assertions import eq_ +from ..assertions import is_ +from ..config import requirements +from ..schema import Column +from ..schema import Table +from ... import and_ +from ... import BigInteger +from ... import bindparam +from ... import Boolean +from ... import case +from ... import cast +from ... import Date +from ... import DateTime +from ... import Float +from ... import Integer +from ... import JSON +from ... import literal +from ... import MetaData +from ... import null +from ... import Numeric +from ... import select +from ... import String +from ... import testing +from ... import Text +from ... import Time +from ... import TIMESTAMP +from ... import TypeDecorator +from ... import Unicode +from ... import UnicodeText +from ... import util +from ...orm import declarative_base +from ...orm import Session +from ...sql.sqltypes import LargeBinary +from ...sql.sqltypes import PickleType +from ...util import compat +from ...util import u + + +class _LiteralRoundTripFixture(object): + supports_whereclause = True + + @testing.fixture + def literal_round_trip(self, metadata, connection): + """test literal rendering""" + + # for literal, we test the literal render in an INSERT + # into a typed column. we can then SELECT it back as its + # official type; ideally we'd be able to use CAST here + # but MySQL in particular can't CAST fully + + def run(type_, input_, output, filter_=None): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + + for value in input_: + ins = ( + t.insert() + .values(x=literal(value, type_)) + .compile( + dialect=testing.db.dialect, + compile_kwargs=dict(literal_binds=True), + ) + ) + connection.execute(ins) + + if self.supports_whereclause: + stmt = t.select().where(t.c.x == literal(value)) + else: + stmt = t.select() + + stmt = stmt.compile( + dialect=testing.db.dialect, + compile_kwargs=dict(literal_binds=True), + ) + for row in connection.execute(stmt): + value = row[0] + if filter_ is not None: + value = filter_(value) + assert value in output + + return run + + +class _UnicodeFixture(_LiteralRoundTripFixture, fixtures.TestBase): + __requires__ = ("unicode_data",) + + data = u( + "Alors vous imaginez ma 🐍 surprise, au lever du jour, " + "quand une drôle de petite 🐍 voix m’a réveillé. Elle " + "disait: « S’il vous plaît… dessine-moi 🐍 un mouton! »" + ) + + @property + def supports_whereclause(self): + return config.requirements.expressions_against_unbounded_text.enabled + + @classmethod + def define_tables(cls, metadata): + Table( + "unicode_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("unicode_data", cls.datatype), + ) + + def test_round_trip(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": self.data} + ) + + row = connection.execute(select(unicode_table.c.unicode_data)).first() + + eq_(row, (self.data,)) + assert isinstance(row[0], util.text_type) + + def test_round_trip_executemany(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), + [{"id": i, "unicode_data": self.data} for i in range(1, 4)], + ) + + rows = connection.execute( + select(unicode_table.c.unicode_data) + ).fetchall() + eq_(rows, [(self.data,) for i in range(1, 4)]) + for row in rows: + assert isinstance(row[0], util.text_type) + + def _test_null_strings(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": None} + ) + row = connection.execute(select(unicode_table.c.unicode_data)).first() + eq_(row, (None,)) + + def _test_empty_strings(self, connection): + unicode_table = self.tables.unicode_table + + connection.execute( + unicode_table.insert(), {"id": 1, "unicode_data": u("")} + ) + row = connection.execute(select(unicode_table.c.unicode_data)).first() + eq_(row, (u(""),)) + + def test_literal(self, literal_round_trip): + literal_round_trip(self.datatype, [self.data], [self.data]) + + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( + self.datatype, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] + ) + + +class UnicodeVarcharTest(_UnicodeFixture, fixtures.TablesTest): + __requires__ = ("unicode_data",) + __backend__ = True + + datatype = Unicode(255) + + @requirements.empty_strings_varchar + def test_empty_strings_varchar(self, connection): + self._test_empty_strings(connection) + + def test_null_strings_varchar(self, connection): + self._test_null_strings(connection) + + +class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): + __requires__ = "unicode_data", "text_type" + __backend__ = True + + datatype = UnicodeText() + + @requirements.empty_strings_text + def test_empty_strings_text(self, connection): + self._test_empty_strings(connection) + + def test_null_strings_text(self, connection): + self._test_null_strings(connection) + + +class BinaryTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __requires__ = ("binary_literals",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "binary_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("binary_data", LargeBinary), + Column("pickle_data", PickleType), + ) + + def test_binary_roundtrip(self, connection): + binary_table = self.tables.binary_table + + connection.execute( + binary_table.insert(), {"id": 1, "binary_data": b"this is binary"} + ) + row = connection.execute(select(binary_table.c.binary_data)).first() + eq_(row, (b"this is binary",)) + + def test_pickle_roundtrip(self, connection): + binary_table = self.tables.binary_table + + connection.execute( + binary_table.insert(), + {"id": 1, "pickle_data": {"foo": [1, 2, 3], "bar": "bat"}}, + ) + row = connection.execute(select(binary_table.c.pickle_data)).first() + eq_(row, ({"foo": [1, 2, 3], "bar": "bat"},)) + + +class TextTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __requires__ = ("text_type",) + __backend__ = True + + @property + def supports_whereclause(self): + return config.requirements.expressions_against_unbounded_text.enabled + + @classmethod + def define_tables(cls, metadata): + Table( + "text_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("text_data", Text), + ) + + def test_text_roundtrip(self, connection): + text_table = self.tables.text_table + + connection.execute( + text_table.insert(), {"id": 1, "text_data": "some text"} + ) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, ("some text",)) + + @testing.requires.empty_strings_text + def test_text_empty_strings(self, connection): + text_table = self.tables.text_table + + connection.execute(text_table.insert(), {"id": 1, "text_data": ""}) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, ("",)) + + def test_text_null_strings(self, connection): + text_table = self.tables.text_table + + connection.execute(text_table.insert(), {"id": 1, "text_data": None}) + row = connection.execute(select(text_table.c.text_data)).first() + eq_(row, (None,)) + + def test_literal(self, literal_round_trip): + literal_round_trip(Text, ["some text"], ["some text"]) + + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( + Text, [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] + ) + + def test_literal_quoting(self, literal_round_trip): + data = """some 'text' hey "hi there" that's text""" + literal_round_trip(Text, [data], [data]) + + def test_literal_backslashes(self, literal_round_trip): + data = r"backslash one \ backslash two \\ end" + literal_round_trip(Text, [data], [data]) + + def test_literal_percentsigns(self, literal_round_trip): + data = r"percent % signs %% percent" + literal_round_trip(Text, [data], [data]) + + +class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @requirements.unbounded_varchar + def test_nolength_string(self): + metadata = MetaData() + foo = Table("foo", metadata, Column("one", String)) + + foo.create(config.db) + foo.drop(config.db) + + def test_literal(self, literal_round_trip): + # note that in Python 3, this invokes the Unicode + # datatype for the literal part because all strings are unicode + literal_round_trip(String(40), ["some text"], ["some text"]) + + def test_literal_non_ascii(self, literal_round_trip): + literal_round_trip( + String(40), [util.u("réve🐍 illé")], [util.u("réve🐍 illé")] + ) + + def test_literal_quoting(self, literal_round_trip): + data = """some 'text' hey "hi there" that's text""" + literal_round_trip(String(40), [data], [data]) + + def test_literal_backslashes(self, literal_round_trip): + data = r"backslash one \ backslash two \\ end" + literal_round_trip(String(40), [data], [data]) + + +class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): + compare = None + + @classmethod + def define_tables(cls, metadata): + class Decorated(TypeDecorator): + impl = cls.datatype + cache_ok = True + + Table( + "date_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("date_data", cls.datatype), + Column("decorated_date_data", Decorated), + ) + + @testing.requires.datetime_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + def test_round_trip(self, connection): + date_table = self.tables.date_table + + connection.execute( + date_table.insert(), {"id": 1, "date_data": self.data} + ) + + row = connection.execute(select(date_table.c.date_data)).first() + + compare = self.compare or self.data + eq_(row, (compare,)) + assert isinstance(row[0], type(compare)) + + def test_round_trip_decorated(self, connection): + date_table = self.tables.date_table + + connection.execute( + date_table.insert(), {"id": 1, "decorated_date_data": self.data} + ) + + row = connection.execute( + select(date_table.c.decorated_date_data) + ).first() + + compare = self.compare or self.data + eq_(row, (compare,)) + assert isinstance(row[0], type(compare)) + + def test_null(self, connection): + date_table = self.tables.date_table + + connection.execute(date_table.insert(), {"id": 1, "date_data": None}) + + row = connection.execute(select(date_table.c.date_data)).first() + eq_(row, (None,)) + + @testing.requires.datetime_literals + def test_literal(self, literal_round_trip): + compare = self.compare or self.data + literal_round_trip(self.datatype, [self.data], [compare]) + + @testing.requires.standalone_null_binds_whereclause + def test_null_bound_comparison(self): + # this test is based on an Oracle issue observed in #4886. + # passing NULL for an expression that needs to be interpreted as + # a certain type, does the DBAPI have the info it needs to do this. + date_table = self.tables.date_table + with config.db.begin() as conn: + result = conn.execute( + date_table.insert(), {"id": 1, "date_data": self.data} + ) + id_ = result.inserted_primary_key[0] + stmt = select(date_table.c.id).where( + case( + ( + bindparam("foo", type_=self.datatype) != None, + bindparam("foo", type_=self.datatype), + ), + else_=date_table.c.date_data, + ) + == date_table.c.date_data + ) + + row = conn.execute(stmt, {"foo": None}).first() + eq_(row[0], id_) + + +class DateTimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(2012, 10, 15, 12, 57, 18) + + +class DateTimeTZTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_timezone",) + __backend__ = True + datatype = DateTime(timezone=True) + data = datetime.datetime( + 2012, 10, 15, 12, 57, 18, tzinfo=compat.timezone.utc + ) + + +class DateTimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_microseconds",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + + +class TimestampMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("timestamp_microseconds",) + __backend__ = True + datatype = TIMESTAMP + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396) + + @testing.requires.timestamp_microseconds_implicit_bound + def test_select_direct(self, connection): + result = connection.scalar(select(literal(self.data))) + eq_(result, self.data) + + +class TimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time",) + __backend__ = True + datatype = Time + data = datetime.time(12, 57, 18) + + +class TimeTZTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time_timezone",) + __backend__ = True + datatype = Time(timezone=True) + data = datetime.time(12, 57, 18, tzinfo=compat.timezone.utc) + + +class TimeMicrosecondsTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("time_microseconds",) + __backend__ = True + datatype = Time + data = datetime.time(12, 57, 18, 396) + + +class DateTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("date",) + __backend__ = True + datatype = Date + data = datetime.date(2012, 10, 15) + + +class DateTimeCoercedToDateTimeTest(_DateFixture, fixtures.TablesTest): + __requires__ = "date", "date_coerces_from_datetime" + __backend__ = True + datatype = Date + data = datetime.datetime(2012, 10, 15, 12, 57, 18) + compare = datetime.date(2012, 10, 15) + + +class DateTimeHistoricTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("datetime_historic",) + __backend__ = True + datatype = DateTime + data = datetime.datetime(1850, 11, 10, 11, 52, 35) + + +class DateHistoricTest(_DateFixture, fixtures.TablesTest): + __requires__ = ("date_historic",) + __backend__ = True + datatype = Date + data = datetime.date(1727, 4, 1) + + +class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + def test_literal(self, literal_round_trip): + literal_round_trip(Integer, [5], [5]) + + def test_huge_int(self, integer_round_trip): + integer_round_trip(BigInteger, 1376537018368127) + + @testing.fixture + def integer_round_trip(self, metadata, connection): + def run(datatype, data): + int_table = Table( + "integer_table", + metadata, + Column( + "id", + Integer, + primary_key=True, + test_needs_autoincrement=True, + ), + Column("integer_data", datatype), + ) + + metadata.create_all(config.db) + + connection.execute( + int_table.insert(), {"id": 1, "integer_data": data} + ) + + row = connection.execute(select(int_table.c.integer_data)).first() + + eq_(row, (data,)) + + if util.py3k: + assert isinstance(row[0], int) + else: + assert isinstance(row[0], (long, int)) # noqa + + return run + + +class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @testing.fixture + def string_as_int(self): + class StringAsInt(TypeDecorator): + impl = String(50) + cache_ok = True + + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + def column_expression(self, col): + return cast(col, Integer) + + def bind_expression(self, col): + return cast(col, String(50)) + + return StringAsInt() + + def test_special_type(self, metadata, connection, string_as_int): + + type_ = string_as_int + + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + + connection.execute(t.insert(), [{"x": x} for x in [1, 2, 3]]) + + result = {row[0] for row in connection.execute(t.select())} + eq_(result, {1, 2, 3}) + + result = { + row[0] for row in connection.execute(t.select().where(t.c.x == 2)) + } + eq_(result, {2}) + + +class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @testing.fixture + def do_numeric_test(self, metadata, connection): + @testing.emits_warning( + r".*does \*not\* support Decimal objects natively" + ) + def run(type_, input_, output, filter_=None, check_scale=False): + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + connection.execute(t.insert(), [{"x": x} for x in input_]) + + result = {row[0] for row in connection.execute(t.select())} + output = set(output) + if filter_: + result = set(filter_(x) for x in result) + output = set(filter_(x) for x in output) + eq_(result, output) + if check_scale: + eq_([str(x) for x in result], [str(x) for x in output]) + + return run + + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_render_literal_numeric(self, literal_round_trip): + literal_round_trip( + Numeric(precision=8, scale=4), + [15.7563, decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_render_literal_numeric_asfloat(self, literal_round_trip): + literal_round_trip( + Numeric(precision=8, scale=4, asdecimal=False), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + ) + + def test_render_literal_float(self, literal_round_trip): + literal_round_trip( + Float(4), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None, + ) + + @testing.requires.precision_generic_float_type + def test_float_custom_scale(self, do_numeric_test): + do_numeric_test( + Float(None, decimal_return_scale=7, asdecimal=True), + [15.7563827, decimal.Decimal("15.7563827")], + [decimal.Decimal("15.7563827")], + check_scale=True, + ) + + def test_numeric_as_decimal(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4), + [15.7563, decimal.Decimal("15.7563")], + [decimal.Decimal("15.7563")], + ) + + def test_numeric_as_float(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4, asdecimal=False), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + ) + + @testing.requires.infinity_floats + def test_infinity_floats(self, do_numeric_test): + """test for #977, #7283""" + + do_numeric_test( + Float(None), + [float("inf")], + [float("inf")], + ) + + @testing.requires.fetch_null_from_numeric + def test_numeric_null_as_decimal(self, do_numeric_test): + do_numeric_test(Numeric(precision=8, scale=4), [None], [None]) + + @testing.requires.fetch_null_from_numeric + def test_numeric_null_as_float(self, do_numeric_test): + do_numeric_test( + Numeric(precision=8, scale=4, asdecimal=False), [None], [None] + ) + + @testing.requires.floats_to_four_decimals + def test_float_as_decimal(self, do_numeric_test): + do_numeric_test( + Float(precision=8, asdecimal=True), + [15.7563, decimal.Decimal("15.7563"), None], + [decimal.Decimal("15.7563"), None], + filter_=lambda n: n is not None and round(n, 4) or None, + ) + + def test_float_as_float(self, do_numeric_test): + do_numeric_test( + Float(precision=8), + [15.7563, decimal.Decimal("15.7563")], + [15.7563], + filter_=lambda n: n is not None and round(n, 5) or None, + ) + + def test_float_coerce_round_trip(self, connection): + expr = 15.7563 + + val = connection.scalar(select(literal(expr))) + eq_(val, expr) + + # this does not work in MySQL, see #4036, however we choose not + # to render CAST unconditionally since this is kind of an edge case. + + @testing.requires.implicit_decimal_binds + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_decimal_coerce_round_trip(self, connection): + expr = decimal.Decimal("15.7563") + + val = connection.scalar(select(literal(expr))) + eq_(val, expr) + + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_decimal_coerce_round_trip_w_cast(self, connection): + expr = decimal.Decimal("15.7563") + + val = connection.scalar(select(cast(expr, Numeric(10, 4)))) + eq_(val, expr) + + @testing.requires.precision_numerics_general + def test_precision_decimal(self, do_numeric_test): + numbers = set( + [ + decimal.Decimal("54.234246451650"), + decimal.Decimal("0.004354"), + decimal.Decimal("900.0"), + ] + ) + + do_numeric_test(Numeric(precision=18, scale=12), numbers, numbers) + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal(self, do_numeric_test): + """test exceedingly small decimals. + + Decimal reports values with E notation when the exponent + is greater than 6. + + """ + + numbers = set( + [ + decimal.Decimal("1E-2"), + decimal.Decimal("1E-3"), + decimal.Decimal("1E-4"), + decimal.Decimal("1E-5"), + decimal.Decimal("1E-6"), + decimal.Decimal("1E-7"), + decimal.Decimal("1E-8"), + decimal.Decimal("0.01000005940696"), + decimal.Decimal("0.00000005940696"), + decimal.Decimal("0.00000000000696"), + decimal.Decimal("0.70000000000696"), + decimal.Decimal("696E-12"), + ] + ) + do_numeric_test(Numeric(precision=18, scale=14), numbers, numbers) + + @testing.requires.precision_numerics_enotation_large + def test_enotation_decimal_large(self, do_numeric_test): + """test exceedingly large decimals.""" + + numbers = set( + [ + decimal.Decimal("4E+8"), + decimal.Decimal("5748E+15"), + decimal.Decimal("1.521E+15"), + decimal.Decimal("00000000000000.1E+12"), + ] + ) + do_numeric_test(Numeric(precision=25, scale=2), numbers, numbers) + + @testing.requires.precision_numerics_many_significant_digits + def test_many_significant_digits(self, do_numeric_test): + numbers = set( + [ + decimal.Decimal("31943874831932418390.01"), + decimal.Decimal("319438950232418390.273596"), + decimal.Decimal("87673.594069654243"), + ] + ) + do_numeric_test(Numeric(precision=38, scale=12), numbers, numbers) + + @testing.requires.precision_numerics_retains_significant_digits + def test_numeric_no_decimal(self, do_numeric_test): + numbers = set([decimal.Decimal("1.000")]) + do_numeric_test( + Numeric(precision=5, scale=3), numbers, numbers, check_scale=True + ) + + +class BooleanTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "boolean_table", + metadata, + Column("id", Integer, primary_key=True, autoincrement=False), + Column("value", Boolean), + Column("unconstrained_value", Boolean(create_constraint=False)), + ) + + def test_render_literal_bool(self, literal_round_trip): + literal_round_trip(Boolean(), [True, False], [True, False]) + + def test_round_trip(self, connection): + boolean_table = self.tables.boolean_table + + connection.execute( + boolean_table.insert(), + {"id": 1, "value": True, "unconstrained_value": False}, + ) + + row = connection.execute( + select(boolean_table.c.value, boolean_table.c.unconstrained_value) + ).first() + + eq_(row, (True, False)) + assert isinstance(row[0], bool) + + @testing.requires.nullable_booleans + def test_null(self, connection): + boolean_table = self.tables.boolean_table + + connection.execute( + boolean_table.insert(), + {"id": 1, "value": None, "unconstrained_value": None}, + ) + + row = connection.execute( + select(boolean_table.c.value, boolean_table.c.unconstrained_value) + ).first() + + eq_(row, (None, None)) + + def test_whereclause(self): + # testing "WHERE <column>" renders a compatible expression + boolean_table = self.tables.boolean_table + + with config.db.begin() as conn: + conn.execute( + boolean_table.insert(), + [ + {"id": 1, "value": True, "unconstrained_value": True}, + {"id": 2, "value": False, "unconstrained_value": False}, + ], + ) + + eq_( + conn.scalar( + select(boolean_table.c.id).where(boolean_table.c.value) + ), + 1, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where( + boolean_table.c.unconstrained_value + ) + ), + 1, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where(~boolean_table.c.value) + ), + 2, + ) + eq_( + conn.scalar( + select(boolean_table.c.id).where( + ~boolean_table.c.unconstrained_value + ) + ), + 2, + ) + + +class JSONTest(_LiteralRoundTripFixture, fixtures.TablesTest): + __requires__ = ("json_type",) + __backend__ = True + + datatype = JSON + + @classmethod + def define_tables(cls, metadata): + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.datatype, nullable=False), + Column("nulldata", cls.datatype(none_as_null=True)), + ) + + def test_round_trip_data1(self, connection): + self._test_round_trip({"key1": "value1", "key2": "value2"}, connection) + + def _test_round_trip(self, data_element, connection): + data_table = self.tables.data_table + + connection.execute( + data_table.insert(), + {"id": 1, "name": "row1", "data": data_element}, + ) + + row = connection.execute(select(data_table.c.data)).first() + + eq_(row, (data_element,)) + + def _index_fixtures(include_comparison): + + if include_comparison: + # basically SQL Server and MariaDB can kind of do json + # comparison, MySQL, PG and SQLite can't. not worth it. + json_elements = [] + else: + json_elements = [ + ("json", {"foo": "bar"}), + ("json", ["one", "two", "three"]), + (None, {"foo": "bar"}), + (None, ["one", "two", "three"]), + ] + + elements = [ + ("boolean", True), + ("boolean", False), + ("boolean", None), + ("string", "some string"), + ("string", None), + ("string", util.u("réve illé")), + ( + "string", + util.u("réve🐍 illé"), + testing.requires.json_index_supplementary_unicode_element, + ), + ("integer", 15), + ("integer", 1), + ("integer", 0), + ("integer", None), + ("float", 28.5), + ("float", None), + ( + "float", + 1234567.89, + ), + ("numeric", 1234567.89), + # this one "works" because the float value you see here is + # lost immediately to floating point stuff + ("numeric", 99998969694839.983485848, requirements.python3), + ("numeric", 99939.983485848, requirements.python3), + ("_decimal", decimal.Decimal("1234567.89")), + ( + "_decimal", + decimal.Decimal("99998969694839.983485848"), + # fails on SQLite and MySQL (non-mariadb) + requirements.cast_precision_numerics_many_significant_digits, + ), + ( + "_decimal", + decimal.Decimal("99939.983485848"), + ), + ] + json_elements + + def decorate(fn): + fn = testing.combinations(id_="sa", *elements)(fn) + + return fn + + return decorate + + def _json_value_insert(self, connection, datatype, value, data_element): + data_table = self.tables.data_table + if datatype == "_decimal": + + # Python's builtin json serializer basically doesn't support + # Decimal objects without implicit float conversion period. + # users can otherwise use simplejson which supports + # precision decimals + + # https://bugs.python.org/issue16535 + + # inserting as strings to avoid a new fixture around the + # dialect which would have idiosyncrasies for different + # backends. + + class DecimalEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, decimal.Decimal): + return str(o) + return super(DecimalEncoder, self).default(o) + + json_data = json.dumps(data_element, cls=DecimalEncoder) + + # take the quotes out. yup, there is *literally* no other + # way to get Python's json.dumps() to put all the digits in + # the string + json_data = re.sub(r'"(%s)"' % str(value), str(value), json_data) + + datatype = "numeric" + + connection.execute( + data_table.insert().values( + name="row1", + # to pass the string directly to every backend, including + # PostgreSQL which needs the value to be CAST as JSON + # both in the SQL as well as at the prepared statement + # level for asyncpg, while at the same time MySQL + # doesn't even support CAST for JSON, here we are + # sending the string embedded in the SQL without using + # a parameter. + data=bindparam(None, json_data, literal_execute=True), + nulldata=bindparam(None, json_data, literal_execute=True), + ), + ) + else: + connection.execute( + data_table.insert(), + { + "name": "row1", + "data": data_element, + "nulldata": data_element, + }, + ) + + p_s = None + + if datatype: + if datatype == "numeric": + a, b = str(value).split(".") + s = len(b) + p = len(a) + s + + if isinstance(value, decimal.Decimal): + compare_value = value + else: + compare_value = decimal.Decimal(str(value)) + + p_s = (p, s) + else: + compare_value = value + else: + compare_value = value + + return datatype, compare_value, p_s + + @_index_fixtures(False) + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_index_typed_access(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": value} + + with config.db.begin() as conn: + + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data["key1"] + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + roundtrip = conn.scalar(select(expr)) + eq_(roundtrip, compare_value) + if util.py3k: # skip py2k to avoid comparing unicode to str etc. + is_(type(roundtrip), type(compare_value)) + + @_index_fixtures(True) + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_index_typed_comparison(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": value} + + with config.db.begin() as conn: + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data["key1"] + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + row = conn.execute( + select(expr).where(expr == compare_value) + ).first() + + # make sure we get a row even if value is None + eq_(row, (compare_value,)) + + @_index_fixtures(True) + @testing.emits_warning(r".*does \*not\* support Decimal objects natively") + def test_path_typed_comparison(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": {"subkey1": value}} + with config.db.begin() as conn: + + datatype, compare_value, p_s = self._json_value_insert( + conn, datatype, value, data_element + ) + + expr = data_table.c.data[("key1", "subkey1")] + + if datatype: + if datatype == "numeric" and p_s: + expr = expr.as_numeric(*p_s) + else: + expr = getattr(expr, "as_%s" % datatype)() + + row = conn.execute( + select(expr).where(expr == compare_value) + ).first() + + # make sure we get a row even if value is None + eq_(row, (compare_value,)) + + @testing.combinations( + (True,), + (False,), + (None,), + (15,), + (0,), + (-1,), + (-1.0,), + (15.052,), + ("a string",), + (util.u("réve illé"),), + (util.u("réve🐍 illé"),), + ) + def test_single_element_round_trip(self, element): + data_table = self.tables.data_table + data_element = element + with config.db.begin() as conn: + conn.execute( + data_table.insert(), + { + "name": "row1", + "data": data_element, + "nulldata": data_element, + }, + ) + + row = conn.execute( + select(data_table.c.data, data_table.c.nulldata) + ).first() + + eq_(row, (data_element, data_element)) + + def test_round_trip_custom_json(self): + data_table = self.tables.data_table + data_element = {"key1": "data1"} + + js = mock.Mock(side_effect=json.dumps) + jd = mock.Mock(side_effect=json.loads) + engine = engines.testing_engine( + options=dict(json_serializer=js, json_deserializer=jd) + ) + + # support sqlite :memory: database... + data_table.create(engine, checkfirst=True) + with engine.begin() as conn: + conn.execute( + data_table.insert(), {"name": "row1", "data": data_element} + ) + row = conn.execute(select(data_table.c.data)).first() + + eq_(row, (data_element,)) + eq_(js.mock_calls, [mock.call(data_element)]) + eq_(jd.mock_calls, [mock.call(json.dumps(data_element))]) + + @testing.combinations( + ("parameters",), + ("multiparameters",), + ("values",), + ("omit",), + argnames="insert_type", + ) + def test_round_trip_none_as_sql_null(self, connection, insert_type): + col = self.tables.data_table.c["nulldata"] + + conn = connection + + if insert_type == "parameters": + stmt, params = self.tables.data_table.insert(), { + "name": "r1", + "nulldata": None, + "data": None, + } + elif insert_type == "multiparameters": + stmt, params = self.tables.data_table.insert(), [ + {"name": "r1", "nulldata": None, "data": None} + ] + elif insert_type == "values": + stmt, params = ( + self.tables.data_table.insert().values( + name="r1", + nulldata=None, + data=None, + ), + {}, + ) + elif insert_type == "omit": + stmt, params = ( + self.tables.data_table.insert(), + {"name": "r1", "data": None}, + ) + + else: + assert False + + conn.execute(stmt, params) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where(col.is_(null())) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + def test_round_trip_json_null_as_json_null(self, connection): + col = self.tables.data_table.c["data"] + + conn = connection + conn.execute( + self.tables.data_table.insert(), + {"name": "r1", "data": JSON.NULL}, + ) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where( + cast(col, String) == "null" + ) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + @testing.combinations( + ("parameters",), + ("multiparameters",), + ("values",), + argnames="insert_type", + ) + def test_round_trip_none_as_json_null(self, connection, insert_type): + col = self.tables.data_table.c["data"] + + if insert_type == "parameters": + stmt, params = self.tables.data_table.insert(), { + "name": "r1", + "data": None, + } + elif insert_type == "multiparameters": + stmt, params = self.tables.data_table.insert(), [ + {"name": "r1", "data": None} + ] + elif insert_type == "values": + stmt, params = ( + self.tables.data_table.insert().values(name="r1", data=None), + {}, + ) + else: + assert False + + conn = connection + conn.execute(stmt, params) + + eq_( + conn.scalar( + select(self.tables.data_table.c.name).where( + cast(col, String) == "null" + ) + ), + "r1", + ) + + eq_(conn.scalar(select(col)), None) + + def test_unicode_round_trip(self): + # note we include Unicode supplementary characters as well + with config.db.begin() as conn: + conn.execute( + self.tables.data_table.insert(), + { + "name": "r1", + "data": { + util.u("réve🐍 illé"): util.u("réve🐍 illé"), + "data": {"k1": util.u("drôl🐍e")}, + }, + }, + ) + + eq_( + conn.scalar(select(self.tables.data_table.c.data)), + { + util.u("réve🐍 illé"): util.u("réve🐍 illé"), + "data": {"k1": util.u("drôl🐍e")}, + }, + ) + + def test_eval_none_flag_orm(self, connection): + + Base = declarative_base() + + class Data(Base): + __table__ = self.tables.data_table + + with Session(connection) as s: + d1 = Data(name="d1", data=None, nulldata=None) + s.add(d1) + s.commit() + + s.bulk_insert_mappings( + Data, [{"name": "d2", "data": None, "nulldata": None}] + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d1") + .first(), + ("null", None), + ) + eq_( + s.query( + cast(self.tables.data_table.c.data, String()), + cast(self.tables.data_table.c.nulldata, String), + ) + .filter(self.tables.data_table.c.name == "d2") + .first(), + ("null", None), + ) + + +class JSONLegacyStringCastIndexTest( + _LiteralRoundTripFixture, fixtures.TablesTest +): + """test JSON index access with "cast to string", which we have documented + for a long time as how to compare JSON values, but is ultimately not + reliable in all cases. The "as_XYZ()" comparators should be used + instead. + + """ + + __requires__ = ("json_type", "legacy_unconditional_json_extract") + __backend__ = True + + datatype = JSON + + data1 = {"key1": "value1", "key2": "value2"} + + data2 = { + "Key 'One'": "value1", + "key two": "value2", + "key three": "value ' three '", + } + + data3 = { + "key1": [1, 2, 3], + "key2": ["one", "two", "three"], + "key3": [{"four": "five"}, {"six": "seven"}], + } + + data4 = ["one", "two", "three"] + + data5 = { + "nested": { + "elem1": [{"a": "b", "c": "d"}, {"e": "f", "g": "h"}], + "elem2": {"elem3": {"elem4": "elem5"}}, + } + } + + data6 = {"a": 5, "b": "some value", "c": {"foo": "bar"}} + + @classmethod + def define_tables(cls, metadata): + Table( + "data_table", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(30), nullable=False), + Column("data", cls.datatype), + Column("nulldata", cls.datatype(none_as_null=True)), + ) + + def _criteria_fixture(self): + with config.db.begin() as conn: + conn.execute( + self.tables.data_table.insert(), + [ + {"name": "r1", "data": self.data1}, + {"name": "r2", "data": self.data2}, + {"name": "r3", "data": self.data3}, + {"name": "r4", "data": self.data4}, + {"name": "r5", "data": self.data5}, + {"name": "r6", "data": self.data6}, + ], + ) + + def _test_index_criteria(self, crit, expected, test_literal=True): + self._criteria_fixture() + with config.db.connect() as conn: + stmt = select(self.tables.data_table.c.name).where(crit) + + eq_(conn.scalar(stmt), expected) + + if test_literal: + literal_sql = str( + stmt.compile( + config.db, compile_kwargs={"literal_binds": True} + ) + ) + + eq_(conn.exec_driver_sql(literal_sql).scalar(), expected) + + def test_string_cast_crit_spaces_in_key(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + # limit the rows here to avoid PG error + # "cannot extract field from a non-object", which is + # fixed in 9.4 but may exist in 9.3 + self._test_index_criteria( + and_( + name.in_(["r1", "r2", "r3"]), + cast(col["key two"], String) == '"value2"', + ), + "r2", + ) + + @config.requirements.json_array_indexes + def test_string_cast_crit_simple_int(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + # limit the rows here to avoid PG error + # "cannot extract array element from a non-array", which is + # fixed in 9.4 but may exist in 9.3 + self._test_index_criteria( + and_( + name == "r4", + cast(col[1], String) == '"two"', + ), + "r4", + ) + + def test_string_cast_crit_mixed_path(self): + col = self.tables.data_table.c["data"] + self._test_index_criteria( + cast(col[("key3", 1, "six")], String) == '"seven"', + "r3", + ) + + def test_string_cast_crit_string_path(self): + col = self.tables.data_table.c["data"] + self._test_index_criteria( + cast(col[("nested", "elem2", "elem3", "elem4")], String) + == '"elem5"', + "r5", + ) + + def test_string_cast_crit_against_string_basic(self): + name = self.tables.data_table.c.name + col = self.tables.data_table.c["data"] + + self._test_index_criteria( + and_( + name == "r6", + cast(col["b"], String) == '"some value"', + ), + "r6", + ) + + +__all__ = ( + "BinaryTest", + "UnicodeVarcharTest", + "UnicodeTextTest", + "JSONTest", + "JSONLegacyStringCastIndexTest", + "DateTest", + "DateTimeTest", + "DateTimeTZTest", + "TextTest", + "NumericTest", + "IntegerTest", + "CastTypeDecoratorTest", + "DateTimeHistoricTest", + "DateTimeCoercedToDateTimeTest", + "TimeMicrosecondsTest", + "TimestampMicrosecondsTest", + "TimeTest", + "TimeTZTest", + "DateTimeMicrosecondsTest", + "DateHistoricTest", + "StringTest", + "BooleanTest", +) diff --git a/lib/sqlalchemy/testing/suite/test_unicode_ddl.py b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py new file mode 100644 index 0000000..a4ae334 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_unicode_ddl.py @@ -0,0 +1,206 @@ +# coding: utf-8 +"""verrrrry basic unicode column name testing""" + +from sqlalchemy import desc +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import testing +from sqlalchemy import util +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table +from sqlalchemy.util import u +from sqlalchemy.util import ue + + +class UnicodeSchemaTest(fixtures.TablesTest): + __requires__ = ("unicode_ddl",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + global t1, t2, t3 + + t1 = Table( + u("unitable1"), + metadata, + Column(u("méil"), Integer, primary_key=True), + Column(ue("\u6e2c\u8a66"), Integer), + test_needs_fk=True, + ) + t2 = Table( + u("Unitéble2"), + metadata, + Column(u("méil"), Integer, primary_key=True, key="a"), + Column( + ue("\u6e2c\u8a66"), + Integer, + ForeignKey(u("unitable1.méil")), + key="b", + ), + test_needs_fk=True, + ) + + # Few DBs support Unicode foreign keys + if testing.against("sqlite"): + t3 = Table( + ue("\u6e2c\u8a66"), + metadata, + Column( + ue("\u6e2c\u8a66_id"), + Integer, + primary_key=True, + autoincrement=False, + ), + Column( + ue("unitable1_\u6e2c\u8a66"), + Integer, + ForeignKey(ue("unitable1.\u6e2c\u8a66")), + ), + Column( + u("Unitéble2_b"), Integer, ForeignKey(u("Unitéble2.b")) + ), + Column( + ue("\u6e2c\u8a66_self"), + Integer, + ForeignKey(ue("\u6e2c\u8a66.\u6e2c\u8a66_id")), + ), + test_needs_fk=True, + ) + else: + t3 = Table( + ue("\u6e2c\u8a66"), + metadata, + Column( + ue("\u6e2c\u8a66_id"), + Integer, + primary_key=True, + autoincrement=False, + ), + Column(ue("unitable1_\u6e2c\u8a66"), Integer), + Column(u("Unitéble2_b"), Integer), + Column(ue("\u6e2c\u8a66_self"), Integer), + test_needs_fk=True, + ) + + def test_insert(self, connection): + connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5}) + connection.execute(t2.insert(), {u("a"): 1, u("b"): 1}) + connection.execute( + t3.insert(), + { + ue("\u6e2c\u8a66_id"): 1, + ue("unitable1_\u6e2c\u8a66"): 5, + u("Unitéble2_b"): 1, + ue("\u6e2c\u8a66_self"): 1, + }, + ) + + eq_(connection.execute(t1.select()).fetchall(), [(1, 5)]) + eq_(connection.execute(t2.select()).fetchall(), [(1, 1)]) + eq_(connection.execute(t3.select()).fetchall(), [(1, 5, 1, 1)]) + + def test_col_targeting(self, connection): + connection.execute(t1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5}) + connection.execute(t2.insert(), {u("a"): 1, u("b"): 1}) + connection.execute( + t3.insert(), + { + ue("\u6e2c\u8a66_id"): 1, + ue("unitable1_\u6e2c\u8a66"): 5, + u("Unitéble2_b"): 1, + ue("\u6e2c\u8a66_self"): 1, + }, + ) + + row = connection.execute(t1.select()).first() + eq_(row._mapping[t1.c[u("méil")]], 1) + eq_(row._mapping[t1.c[ue("\u6e2c\u8a66")]], 5) + + row = connection.execute(t2.select()).first() + eq_(row._mapping[t2.c[u("a")]], 1) + eq_(row._mapping[t2.c[u("b")]], 1) + + row = connection.execute(t3.select()).first() + eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_id")]], 1) + eq_(row._mapping[t3.c[ue("unitable1_\u6e2c\u8a66")]], 5) + eq_(row._mapping[t3.c[u("Unitéble2_b")]], 1) + eq_(row._mapping[t3.c[ue("\u6e2c\u8a66_self")]], 1) + + def test_reflect(self, connection): + connection.execute(t1.insert(), {u("méil"): 2, ue("\u6e2c\u8a66"): 7}) + connection.execute(t2.insert(), {u("a"): 2, u("b"): 2}) + connection.execute( + t3.insert(), + { + ue("\u6e2c\u8a66_id"): 2, + ue("unitable1_\u6e2c\u8a66"): 7, + u("Unitéble2_b"): 2, + ue("\u6e2c\u8a66_self"): 2, + }, + ) + + meta = MetaData() + tt1 = Table(t1.name, meta, autoload_with=connection) + tt2 = Table(t2.name, meta, autoload_with=connection) + tt3 = Table(t3.name, meta, autoload_with=connection) + + connection.execute(tt1.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 5}) + connection.execute(tt2.insert(), {u("méil"): 1, ue("\u6e2c\u8a66"): 1}) + connection.execute( + tt3.insert(), + { + ue("\u6e2c\u8a66_id"): 1, + ue("unitable1_\u6e2c\u8a66"): 5, + u("Unitéble2_b"): 1, + ue("\u6e2c\u8a66_self"): 1, + }, + ) + + eq_( + connection.execute( + tt1.select().order_by(desc(u("méil"))) + ).fetchall(), + [(2, 7), (1, 5)], + ) + eq_( + connection.execute( + tt2.select().order_by(desc(u("méil"))) + ).fetchall(), + [(2, 2), (1, 1)], + ) + eq_( + connection.execute( + tt3.select().order_by(desc(ue("\u6e2c\u8a66_id"))) + ).fetchall(), + [(2, 7, 2, 2), (1, 5, 1, 1)], + ) + + def test_repr(self): + meta = MetaData() + t = Table( + ue("\u6e2c\u8a66"), meta, Column(ue("\u6e2c\u8a66_id"), Integer) + ) + + if util.py2k: + eq_( + repr(t), + ( + "Table('\\u6e2c\\u8a66', MetaData(), " + "Column('\\u6e2c\\u8a66_id', Integer(), " + "table=<\u6e2c\u8a66>), " + "schema=None)" + ), + ) + else: + eq_( + repr(t), + ( + "Table('測試', MetaData(), " + "Column('測試_id', Integer(), " + "table=<測試>), " + "schema=None)" + ), + ) diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py new file mode 100644 index 0000000..f04a9d5 --- /dev/null +++ b/lib/sqlalchemy/testing/suite/test_update_delete.py @@ -0,0 +1,60 @@ +from .. import fixtures +from ..assertions import eq_ +from ..schema import Column +from ..schema import Table +from ... import Integer +from ... import String + + +class SimpleUpdateDeleteTest(fixtures.TablesTest): + run_deletes = "each" + __requires__ = ("sane_rowcount",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "plain_pk", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.plain_pk.insert(), + [ + {"id": 1, "data": "d1"}, + {"id": 2, "data": "d2"}, + {"id": 3, "data": "d3"}, + ], + ) + + def test_update(self, connection): + t = self.tables.plain_pk + r = connection.execute( + t.update().where(t.c.id == 2), dict(data="d2_new") + ) + assert not r.is_insert + assert not r.returns_rows + assert r.rowcount == 1 + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + assert not r.returns_rows + assert r.rowcount == 1 + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) + + +__all__ = ("SimpleUpdateDeleteTest",) |