diff options
Diffstat (limited to 'lib/sqlalchemy/testing/schema.py')
-rw-r--r-- | lib/sqlalchemy/testing/schema.py | 218 |
1 files changed, 218 insertions, 0 deletions
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py new file mode 100644 index 0000000..bff07a5 --- /dev/null +++ b/lib/sqlalchemy/testing/schema.py @@ -0,0 +1,218 @@ +# testing/schema.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import sys + +from . import config +from . import exclusions +from .. import event +from .. import schema +from .. import types as sqltypes +from ..util import OrderedDict + + +__all__ = ["Table", "Column"] + +table_options = {} + + +def Table(*args, **kw): + """A schema.Table wrapper/hook for dialect-specific tweaks.""" + + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} + + kw.update(table_options) + + if exclusions.against(config._current, "mysql"): + if ( + "mysql_engine" not in kw + and "mysql_type" not in kw + and "autoload_with" not in kw + ): + if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts: + kw["mysql_engine"] = "InnoDB" + else: + kw["mysql_engine"] = "MyISAM" + elif exclusions.against(config._current, "mariadb"): + if ( + "mariadb_engine" not in kw + and "mariadb_type" not in kw + and "autoload_with" not in kw + ): + if "test_needs_fk" in test_opts or "test_needs_acid" in test_opts: + kw["mariadb_engine"] = "InnoDB" + else: + kw["mariadb_engine"] = "MyISAM" + + # Apply some default cascading rules for self-referential foreign keys. + # MySQL InnoDB has some issues around selecting self-refs too. + if exclusions.against(config._current, "firebird"): + table_name = args[0] + unpack = config.db.dialect.identifier_preparer.unformat_identifiers + + # Only going after ForeignKeys in Columns. May need to + # expand to ForeignKeyConstraint too. + fks = [ + fk + for col in args + if isinstance(col, schema.Column) + for fk in col.foreign_keys + ] + + for fk in fks: + # root around in raw spec + ref = fk._colspec + if isinstance(ref, schema.Column): + name = ref.table.name + else: + # take just the table name: on FB there cannot be + # a schema, so the first element is always the + # table name, possibly followed by the field name + name = unpack(ref)[0] + if name == table_name: + if fk.ondelete is None: + fk.ondelete = "CASCADE" + if fk.onupdate is None: + fk.onupdate = "CASCADE" + + return schema.Table(*args, **kw) + + +def Column(*args, **kw): + """A schema.Column wrapper/hook for dialect-specific tweaks.""" + + test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} + + if not config.requirements.foreign_key_ddl.enabled_for_config(config): + args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)] + + col = schema.Column(*args, **kw) + if test_opts.get("test_needs_autoincrement", False) and kw.get( + "primary_key", False + ): + + if col.default is None and col.server_default is None: + col.autoincrement = True + + # allow any test suite to pick up on this + col.info["test_needs_autoincrement"] = True + + # hardcoded rule for firebird, oracle; this should + # be moved out + if exclusions.against(config._current, "firebird", "oracle"): + + def add_seq(c, tbl): + c._init_items( + schema.Sequence( + _truncate_name( + config.db.dialect, tbl.name + "_" + c.name + "_seq" + ), + optional=True, + ) + ) + + event.listen(col, "after_parent_attach", add_seq, propagate=True) + return col + + +class eq_type_affinity(object): + """Helper to compare types inside of datastructures based on affinity. + + E.g.:: + + eq_( + inspect(connection).get_columns("foo"), + [ + { + "name": "id", + "type": testing.eq_type_affinity(sqltypes.INTEGER), + "nullable": False, + "default": None, + "autoincrement": False, + }, + { + "name": "data", + "type": testing.eq_type_affinity(sqltypes.NullType), + "nullable": True, + "default": None, + "autoincrement": False, + }, + ], + ) + + """ + + def __init__(self, target): + self.target = sqltypes.to_instance(target) + + def __eq__(self, other): + return self.target._type_affinity is other._type_affinity + + def __ne__(self, other): + return self.target._type_affinity is not other._type_affinity + + +class eq_clause_element(object): + """Helper to compare SQL structures based on compare()""" + + def __init__(self, target): + self.target = target + + def __eq__(self, other): + return self.target.compare(other) + + def __ne__(self, other): + return not self.target.compare(other) + + +def _truncate_name(dialect, name): + if len(name) > dialect.max_identifier_length: + return ( + name[0 : max(dialect.max_identifier_length - 6, 0)] + + "_" + + hex(hash(name) % 64)[2:] + ) + else: + return name + + +def pep435_enum(name): + # Implements PEP 435 in the minimal fashion needed by SQLAlchemy + __members__ = OrderedDict() + + def __init__(self, name, value, alias=None): + self.name = name + self.value = value + self.__members__[name] = self + value_to_member[value] = self + setattr(self.__class__, name, self) + if alias: + self.__members__[alias] = self + setattr(self.__class__, alias, self) + + value_to_member = {} + + @classmethod + def get(cls, value): + return value_to_member[value] + + someenum = type( + name, + (object,), + {"__members__": __members__, "__init__": __init__, "get": get}, + ) + + # getframe() trick for pickling I don't understand courtesy + # Python namedtuple() + try: + module = sys._getframe(1).f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): + pass + if module is not None: + someenum.__module__ = module + + return someenum |