diff options
Diffstat (limited to 'lib/sqlalchemy/dialects')
65 files changed, 34330 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py new file mode 100644 index 0000000..84a9ad8 --- /dev/null +++ b/lib/sqlalchemy/dialects/__init__.py @@ -0,0 +1,72 @@ +# dialects/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +__all__ = ( + "firebird", + "mssql", + "mysql", + "oracle", + "postgresql", + "sqlite", + "sybase", +) + + +from .. import util + + +def _auto_fn(name): + """default dialect importer. + + plugs into the :class:`.PluginLoader` + as a first-hit system. + + """ + if "." in name: + dialect, driver = name.split(".") + else: + dialect = name + driver = "base" + + try: + if dialect == "firebird": + try: + module = __import__("sqlalchemy_firebird") + except ImportError: + module = __import__("sqlalchemy.dialects.firebird").dialects + module = getattr(module, dialect) + elif dialect == "sybase": + try: + module = __import__("sqlalchemy_sybase") + except ImportError: + module = __import__("sqlalchemy.dialects.sybase").dialects + module = getattr(module, dialect) + elif dialect == "mariadb": + # it's "OK" for us to hardcode here since _auto_fn is already + # hardcoded. if mysql / mariadb etc were third party dialects + # they would just publish all the entrypoints, which would actually + # look much nicer. + module = __import__( + "sqlalchemy.dialects.mysql.mariadb" + ).dialects.mysql.mariadb + return module.loader(driver) + else: + module = __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects + module = getattr(module, dialect) + except ImportError: + return None + + if hasattr(module, driver): + module = getattr(module, driver) + return lambda: module.dialect + else: + return None + + +registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn) + +plugins = util.PluginLoader("sqlalchemy.plugins") diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py new file mode 100644 index 0000000..a34eecf --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/__init__.py @@ -0,0 +1,41 @@ +# firebird/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from sqlalchemy.dialects.firebird.base import BIGINT +from sqlalchemy.dialects.firebird.base import BLOB +from sqlalchemy.dialects.firebird.base import CHAR +from sqlalchemy.dialects.firebird.base import DATE +from sqlalchemy.dialects.firebird.base import FLOAT +from sqlalchemy.dialects.firebird.base import NUMERIC +from sqlalchemy.dialects.firebird.base import SMALLINT +from sqlalchemy.dialects.firebird.base import TEXT +from sqlalchemy.dialects.firebird.base import TIME +from sqlalchemy.dialects.firebird.base import TIMESTAMP +from sqlalchemy.dialects.firebird.base import VARCHAR +from . import base # noqa +from . import fdb # noqa +from . import kinterbasdb # noqa + + +base.dialect = dialect = fdb.dialect + +__all__ = ( + "SMALLINT", + "BIGINT", + "FLOAT", + "FLOAT", + "DATE", + "TIME", + "TEXT", + "NUMERIC", + "FLOAT", + "TIMESTAMP", + "VARCHAR", + "CHAR", + "BLOB", + "dialect", +) diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py new file mode 100644 index 0000000..e2698b1 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -0,0 +1,989 @@ +# firebird/base.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" + +.. dialect:: firebird + :name: Firebird + +.. note:: + + The Firebird dialect within SQLAlchemy **is not currently supported**. + It is not tested within continuous integration and is likely to have + many issues and caveats not currently handled. Consider using the + `external dialect <https://github.com/pauldex/sqlalchemy-firebird>`_ + instead. + +.. deprecated:: 1.4 The internal Firebird dialect is deprecated and will be + removed in a future version. Use the external dialect. + +Firebird Dialects +----------------- + +Firebird offers two distinct dialects_ (not to be confused with a +SQLAlchemy ``Dialect``): + +dialect 1 + This is the old syntax and behaviour, inherited from Interbase pre-6.0. + +dialect 3 + This is the newer and supported syntax, introduced in Interbase 6.0. + +The SQLAlchemy Firebird dialect detects these versions and +adjusts its representation of SQL accordingly. However, +support for dialect 1 is not well tested and probably has +incompatibilities. + +Locking Behavior +---------------- + +Firebird locks tables aggressively. For this reason, a DROP TABLE may +hang until other transactions are released. SQLAlchemy does its best +to release transactions as quickly as possible. The most common cause +of hanging transactions is a non-fully consumed result set, i.e.:: + + result = engine.execute(text("select * from table")) + row = result.fetchone() + return + +Where above, the ``CursorResult`` has not been fully consumed. The +connection will be returned to the pool and the transactional state +rolled back once the Python garbage collector reclaims the objects +which hold onto the connection, which often occurs asynchronously. +The above use case can be alleviated by calling ``first()`` on the +``CursorResult`` which will fetch the first row and immediately close +all remaining cursor/connection resources. + +RETURNING support +----------------- + +Firebird 2.0 supports returning a result set from inserts, and 2.1 +extends that to deletes and updates. This is generically exposed by +the SQLAlchemy ``returning()`` method, such as:: + + # INSERT..RETURNING + result = table.insert().returning(table.c.col1, table.c.col2).\ + values(name='foo') + print(result.fetchall()) + + # UPDATE..RETURNING + raises = empl.update().returning(empl.c.id, empl.c.salary).\ + where(empl.c.sales>100).\ + values(dict(salary=empl.c.salary * 1.1)) + print(raises.fetchall()) + + +.. _dialects: https://mc-computing.com/Databases/Firebird/SQL_Dialect.html +""" + +import datetime + +from sqlalchemy import exc +from sqlalchemy import sql +from sqlalchemy import types as sqltypes +from sqlalchemy import util +from sqlalchemy.engine import default +from sqlalchemy.engine import reflection +from sqlalchemy.sql import compiler +from sqlalchemy.sql import expression +from sqlalchemy.types import BIGINT +from sqlalchemy.types import BLOB +from sqlalchemy.types import DATE +from sqlalchemy.types import FLOAT +from sqlalchemy.types import INTEGER +from sqlalchemy.types import Integer +from sqlalchemy.types import NUMERIC +from sqlalchemy.types import SMALLINT +from sqlalchemy.types import TEXT +from sqlalchemy.types import TIME +from sqlalchemy.types import TIMESTAMP + + +RESERVED_WORDS = set( + [ + "active", + "add", + "admin", + "after", + "all", + "alter", + "and", + "any", + "as", + "asc", + "ascending", + "at", + "auto", + "avg", + "before", + "begin", + "between", + "bigint", + "bit_length", + "blob", + "both", + "by", + "case", + "cast", + "char", + "character", + "character_length", + "char_length", + "check", + "close", + "collate", + "column", + "commit", + "committed", + "computed", + "conditional", + "connect", + "constraint", + "containing", + "count", + "create", + "cross", + "cstring", + "current", + "current_connection", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_transaction", + "current_user", + "cursor", + "database", + "date", + "day", + "dec", + "decimal", + "declare", + "default", + "delete", + "desc", + "descending", + "disconnect", + "distinct", + "do", + "domain", + "double", + "drop", + "else", + "end", + "entry_point", + "escape", + "exception", + "execute", + "exists", + "exit", + "external", + "extract", + "fetch", + "file", + "filter", + "float", + "for", + "foreign", + "from", + "full", + "function", + "gdscode", + "generator", + "gen_id", + "global", + "grant", + "group", + "having", + "hour", + "if", + "in", + "inactive", + "index", + "inner", + "input_type", + "insensitive", + "insert", + "int", + "integer", + "into", + "is", + "isolation", + "join", + "key", + "leading", + "left", + "length", + "level", + "like", + "long", + "lower", + "manual", + "max", + "maximum_segment", + "merge", + "min", + "minute", + "module_name", + "month", + "names", + "national", + "natural", + "nchar", + "no", + "not", + "null", + "numeric", + "octet_length", + "of", + "on", + "only", + "open", + "option", + "or", + "order", + "outer", + "output_type", + "overflow", + "page", + "pages", + "page_size", + "parameter", + "password", + "plan", + "position", + "post_event", + "precision", + "primary", + "privileges", + "procedure", + "protected", + "rdb$db_key", + "read", + "real", + "record_version", + "recreate", + "recursive", + "references", + "release", + "reserv", + "reserving", + "retain", + "returning_values", + "returns", + "revoke", + "right", + "rollback", + "rows", + "row_count", + "savepoint", + "schema", + "second", + "segment", + "select", + "sensitive", + "set", + "shadow", + "shared", + "singular", + "size", + "smallint", + "snapshot", + "some", + "sort", + "sqlcode", + "stability", + "start", + "starting", + "starts", + "statistics", + "sub_type", + "sum", + "suspend", + "table", + "then", + "time", + "timestamp", + "to", + "trailing", + "transaction", + "trigger", + "trim", + "uncommitted", + "union", + "unique", + "update", + "upper", + "user", + "using", + "value", + "values", + "varchar", + "variable", + "varying", + "view", + "wait", + "when", + "where", + "while", + "with", + "work", + "write", + "year", + ] +) + + +class _StringType(sqltypes.String): + """Base for Firebird string types.""" + + def __init__(self, charset=None, **kw): + self.charset = charset + super(_StringType, self).__init__(**kw) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """Firebird VARCHAR type""" + + __visit_name__ = "VARCHAR" + + def __init__(self, length=None, **kwargs): + super(VARCHAR, self).__init__(length=length, **kwargs) + + +class CHAR(_StringType, sqltypes.CHAR): + """Firebird CHAR type""" + + __visit_name__ = "CHAR" + + def __init__(self, length=None, **kwargs): + super(CHAR, self).__init__(length=length, **kwargs) + + +class _FBDateTime(sqltypes.DateTime): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + + return process + + +colspecs = {sqltypes.DateTime: _FBDateTime} + +ischema_names = { + "SHORT": SMALLINT, + "LONG": INTEGER, + "QUAD": FLOAT, + "FLOAT": FLOAT, + "DATE": DATE, + "TIME": TIME, + "TEXT": TEXT, + "INT64": BIGINT, + "DOUBLE": FLOAT, + "TIMESTAMP": TIMESTAMP, + "VARYING": VARCHAR, + "CSTRING": CHAR, + "BLOB": BLOB, +} + + +# TODO: date conversion types (should be implemented as _FBDateTime, +# _FBDate, etc. as bind/result functionality is required) + + +class FBTypeCompiler(compiler.GenericTypeCompiler): + def visit_boolean(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) + + def visit_datetime(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) + + def visit_TEXT(self, type_, **kw): + return "BLOB SUB_TYPE 1" + + def visit_BLOB(self, type_, **kw): + return "BLOB SUB_TYPE 0" + + def _extend_string(self, type_, basic): + charset = getattr(type_, "charset", None) + if charset is None: + return basic + else: + return "%s CHARACTER SET %s" % (basic, charset) + + def visit_CHAR(self, type_, **kw): + basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw) + return self._extend_string(type_, basic) + + def visit_VARCHAR(self, type_, **kw): + if not type_.length: + raise exc.CompileError( + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) + basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw) + return self._extend_string(type_, basic) + + +class FBCompiler(sql.compiler.SQLCompiler): + """Firebird specific idiosyncrasies""" + + ansi_bind_rules = True + + # def visit_contains_op_binary(self, binary, operator, **kw): + # cant use CONTAINING b.c. it's case insensitive. + + # def visit_not_contains_op_binary(self, binary, operator, **kw): + # cant use NOT CONTAINING b.c. it's case insensitive. + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_startswith_op_binary(self, binary, operator, **kw): + return "%s STARTING WITH %s" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + + def visit_not_startswith_op_binary(self, binary, operator, **kw): + return "%s NOT STARTING WITH %s" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + + def visit_mod_binary(self, binary, operator, **kw): + return "mod(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_alias(self, alias, asfrom=False, **kwargs): + if self.dialect._version_two: + return super(FBCompiler, self).visit_alias( + alias, asfrom=asfrom, **kwargs + ) + else: + # Override to not use the AS keyword which FB 1.5 does not like + if asfrom: + alias_name = ( + isinstance(alias.name, expression._truncated_label) + and self._truncated_identifier("alias", alias.name) + or alias.name + ) + + return ( + self.process(alias.element, asfrom=asfrom, **kwargs) + + " " + + self.preparer.format_alias(alias, alias_name) + ) + else: + return self.process(alias.element, **kwargs) + + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0]) + start = self.process(func.clauses.clauses[1]) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2]) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) + else: + return "SUBSTRING(%s FROM %s)" % (s, start) + + def visit_length_func(self, function, **kw): + if self.dialect._version_two: + return "char_length" + self.function_argspec(function) + else: + return "strlen" + self.function_argspec(function) + + visit_char_length_func = visit_length_func + + def function_argspec(self, func, **kw): + # TODO: this probably will need to be + # narrowed to a fixed list, some no-arg functions + # may require parens - see similar example in the oracle + # dialect + if func.clauses is not None and len(func.clauses): + return self.process(func.clause_expr, **kw) + else: + return "" + + def default_from(self): + return " FROM rdb$database" + + def visit_sequence(self, seq, **kw): + return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) + + def get_select_precolumns(self, select, **kw): + """Called when building a ``SELECT`` statement, position is just + before column list Firebird puts the limit and offset right + after the ``SELECT``... + """ + + result = "" + if select._limit_clause is not None: + result += "FIRST %s " % self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + result += "SKIP %s " % self.process(select._offset_clause, **kw) + result += super(FBCompiler, self).get_select_precolumns(select, **kw) + return result + + def limit_clause(self, select, **kw): + """Already taken care of in the `get_select_precolumns` method.""" + + return "" + + def returning_clause(self, stmt, returning_cols): + columns = [ + self._label_returning_column(stmt, c) + for c in expression._select_iterables(returning_cols) + ] + + return "RETURNING " + ", ".join(columns) + + +class FBDDLCompiler(sql.compiler.DDLCompiler): + """Firebird syntactic idiosyncrasies""" + + def visit_create_sequence(self, create): + """Generate a ``CREATE GENERATOR`` statement for the sequence.""" + + # no syntax for these + # https://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html + if create.element.start is not None: + raise NotImplementedError( + "Firebird SEQUENCE doesn't support START WITH" + ) + if create.element.increment is not None: + raise NotImplementedError( + "Firebird SEQUENCE doesn't support INCREMENT BY" + ) + + if self.dialect._version_two: + return "CREATE SEQUENCE %s" % self.preparer.format_sequence( + create.element + ) + else: + return "CREATE GENERATOR %s" % self.preparer.format_sequence( + create.element + ) + + def visit_drop_sequence(self, drop): + """Generate a ``DROP GENERATOR`` statement for the sequence.""" + + if self.dialect._version_two: + return "DROP SEQUENCE %s" % self.preparer.format_sequence( + drop.element + ) + else: + return "DROP GENERATOR %s" % self.preparer.format_sequence( + drop.element + ) + + def visit_computed_column(self, generated): + if generated.persisted is not None: + raise exc.CompileError( + "Firebird computed columns do not support a persistence " + "method setting; set the 'persisted' flag to None for " + "Firebird support." + ) + return "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + + +class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): + """Install Firebird specific reserved words.""" + + reserved_words = RESERVED_WORDS + illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union( + ["_"] + ) + + def __init__(self, dialect): + super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) + + +class FBExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq, type_): + """Get the next value from the sequence using ``gen_id()``.""" + + return self._execute_scalar( + "SELECT gen_id(%s, 1) FROM rdb$database" + % self.identifier_preparer.format_sequence(seq), + type_, + ) + + +class FBDialect(default.DefaultDialect): + """Firebird dialect""" + + name = "firebird" + supports_statement_cache = True + + max_identifier_length = 31 + + supports_sequences = True + sequences_optional = False + supports_default_values = True + postfetch_lastrowid = False + + supports_native_boolean = False + + requires_name_normalize = True + supports_empty_insert = False + + statement_compiler = FBCompiler + ddl_compiler = FBDDLCompiler + preparer = FBIdentifierPreparer + type_compiler = FBTypeCompiler + execution_ctx_cls = FBExecutionContext + + colspecs = colspecs + ischema_names = ischema_names + + construct_arguments = [] + + # defaults to dialect ver. 3, + # will be autodetected off upon + # first connect + _version_two = True + + def __init__(self, *args, **kwargs): + util.warn_deprecated( + "The firebird dialect is deprecated and will be removed " + "in a future version. This dialect is superseded by the external " + "dialect https://github.com/pauldex/sqlalchemy-firebird.", + version="1.4", + ) + super(FBDialect, self).__init__(*args, **kwargs) + + def initialize(self, connection): + super(FBDialect, self).initialize(connection) + self._version_two = ( + "firebird" in self.server_version_info + and self.server_version_info >= (2,) + ) or ( + "interbase" in self.server_version_info + and self.server_version_info >= (6,) + ) + + if not self._version_two: + # TODO: whatever other pre < 2.0 stuff goes here + self.ischema_names = ischema_names.copy() + self.ischema_names["TIMESTAMP"] = sqltypes.DATE + self.colspecs = {sqltypes.DateTime: sqltypes.DATE} + + self.implicit_returning = self._version_two and self.__dict__.get( + "implicit_returning", True + ) + + def has_table(self, connection, table_name, schema=None): + """Return ``True`` if the given table exists, ignoring + the `schema`.""" + self._ensure_has_table_connection(connection) + + tblqry = """ + SELECT 1 AS has_table FROM rdb$database + WHERE EXISTS (SELECT rdb$relation_name + FROM rdb$relations + WHERE rdb$relation_name=?) + """ + c = connection.exec_driver_sql( + tblqry, [self.denormalize_name(table_name)] + ) + return c.first() is not None + + def has_sequence(self, connection, sequence_name, schema=None): + """Return ``True`` if the given sequence (generator) exists.""" + + genqry = """ + SELECT 1 AS has_sequence FROM rdb$database + WHERE EXISTS (SELECT rdb$generator_name + FROM rdb$generators + WHERE rdb$generator_name=?) + """ + c = connection.exec_driver_sql( + genqry, [self.denormalize_name(sequence_name)] + ) + return c.first() is not None + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + # there are two queries commonly mentioned for this. + # this one, using view_blr, is at the Firebird FAQ among other places: + # https://www.firebirdfaq.org/faq174/ + s = """ + select rdb$relation_name + from rdb$relations + where rdb$view_blr is null + and (rdb$system_flag is null or rdb$system_flag = 0); + """ + + # the other query is this one. It's not clear if there's really + # any difference between these two. This link: + # https://www.alberton.info/firebird_sql_meta_info.html#.Ur3vXfZGni8 + # states them as interchangeable. Some discussion at [ticket:2898] + # SELECT DISTINCT rdb$relation_name + # FROM rdb$relation_fields + # WHERE rdb$system_flag=0 AND rdb$view_context IS NULL + + return [ + self.normalize_name(row[0]) + for row in connection.exec_driver_sql(s) + ] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + # see https://www.firebirdfaq.org/faq174/ + s = """ + select rdb$relation_name + from rdb$relations + where rdb$view_blr is not null + and (rdb$system_flag is null or rdb$system_flag = 0); + """ + return [ + self.normalize_name(row[0]) + for row in connection.exec_driver_sql(s) + ] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + qry = """ + SELECT rdb$view_source AS view_source + FROM rdb$relations + WHERE rdb$relation_name=? + """ + rp = connection.exec_driver_sql( + qry, [self.denormalize_name(view_name)] + ) + row = rp.first() + if row: + return row["view_source"] + else: + return None + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + # Query to extract the PK/FK constrained fields of the given table + keyqry = """ + SELECT se.rdb$field_name AS fname + FROM rdb$relation_constraints rc + JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + """ + tablename = self.denormalize_name(table_name) + # get primary key fields + c = connection.exec_driver_sql(keyqry, ["PRIMARY KEY", tablename]) + pkfields = [self.normalize_name(r["fname"]) for r in c.fetchall()] + return {"constrained_columns": pkfields, "name": None} + + @reflection.cache + def get_column_sequence( + self, connection, table_name, column_name, schema=None, **kw + ): + tablename = self.denormalize_name(table_name) + colname = self.denormalize_name(column_name) + # Heuristic-query to determine the generator associated to a PK field + genqry = """ + SELECT trigdep.rdb$depended_on_name AS fgenerator + FROM rdb$dependencies tabdep + JOIN rdb$dependencies trigdep + ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name + AND trigdep.rdb$depended_on_type=14 + AND trigdep.rdb$dependent_type=2 + JOIN rdb$triggers trig ON + trig.rdb$trigger_name=tabdep.rdb$dependent_name + WHERE tabdep.rdb$depended_on_name=? + AND tabdep.rdb$depended_on_type=0 + AND trig.rdb$trigger_type=1 + AND tabdep.rdb$field_name=? + AND (SELECT count(*) + FROM rdb$dependencies trigdep2 + WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2 + """ + genr = connection.exec_driver_sql(genqry, [tablename, colname]).first() + if genr is not None: + return dict(name=self.normalize_name(genr["fgenerator"])) + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + # Query to extract the details of all the fields of the given table + tblqry = """ + SELECT r.rdb$field_name AS fname, + r.rdb$null_flag AS null_flag, + t.rdb$type_name AS ftype, + f.rdb$field_sub_type AS stype, + f.rdb$field_length/ + COALESCE(cs.rdb$bytes_per_character,1) AS flen, + f.rdb$field_precision AS fprec, + f.rdb$field_scale AS fscale, + COALESCE(r.rdb$default_source, + f.rdb$default_source) AS fdefault + FROM rdb$relation_fields r + JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name + JOIN rdb$types t + ON t.rdb$type=f.rdb$field_type AND + t.rdb$field_name='RDB$FIELD_TYPE' + LEFT JOIN rdb$character_sets cs ON + f.rdb$character_set_id=cs.rdb$character_set_id + WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=? + ORDER BY r.rdb$field_position + """ + # get the PK, used to determine the eventual associated sequence + pk_constraint = self.get_pk_constraint(connection, table_name) + pkey_cols = pk_constraint["constrained_columns"] + + tablename = self.denormalize_name(table_name) + # get all of the fields for this table + c = connection.exec_driver_sql(tblqry, [tablename]) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + name = self.normalize_name(row["fname"]) + orig_colname = row["fname"] + + # get the data type + colspec = row["ftype"].rstrip() + coltype = self.ischema_names.get(colspec) + if coltype is None: + util.warn( + "Did not recognize type '%s' of column '%s'" + % (colspec, name) + ) + coltype = sqltypes.NULLTYPE + elif issubclass(coltype, Integer) and row["fprec"] != 0: + coltype = NUMERIC( + precision=row["fprec"], scale=row["fscale"] * -1 + ) + elif colspec in ("VARYING", "CSTRING"): + coltype = coltype(row["flen"]) + elif colspec == "TEXT": + coltype = TEXT(row["flen"]) + elif colspec == "BLOB": + if row["stype"] == 1: + coltype = TEXT() + else: + coltype = BLOB() + else: + coltype = coltype() + + # does it have a default value? + defvalue = None + if row["fdefault"] is not None: + # the value comes down as "DEFAULT 'value'": there may be + # more than one whitespace around the "DEFAULT" keyword + # and it may also be lower case + # (see also https://tracker.firebirdsql.org/browse/CORE-356) + defexpr = row["fdefault"].lstrip() + assert defexpr[:8].rstrip().upper() == "DEFAULT", ( + "Unrecognized default value: %s" % defexpr + ) + defvalue = defexpr[8:].strip() + if defvalue == "NULL": + # Redundant + defvalue = None + col_d = { + "name": name, + "type": coltype, + "nullable": not bool(row["null_flag"]), + "default": defvalue, + "autoincrement": "auto", + } + + if orig_colname.lower() == orig_colname: + col_d["quote"] = True + + # if the PK is a single field, try to see if its linked to + # a sequence thru a trigger + if len(pkey_cols) == 1 and name == pkey_cols[0]: + seq_d = self.get_column_sequence(connection, tablename, name) + if seq_d is not None: + col_d["sequence"] = seq_d + + cols.append(col_d) + return cols + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # Query to extract the details of each UK/FK of the given table + fkqry = """ + SELECT rc.rdb$constraint_name AS cname, + cse.rdb$field_name AS fname, + ix2.rdb$relation_name AS targetrname, + se.rdb$field_name AS targetfname + FROM rdb$relation_constraints rc + JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name + JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key + JOIN rdb$index_segments cse ON + cse.rdb$index_name=ix1.rdb$index_name + JOIN rdb$index_segments se + ON se.rdb$index_name=ix2.rdb$index_name + AND se.rdb$field_position=cse.rdb$field_position + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + ORDER BY se.rdb$index_name, se.rdb$field_position + """ + tablename = self.denormalize_name(table_name) + + c = connection.exec_driver_sql(fkqry, ["FOREIGN KEY", tablename]) + fks = util.defaultdict( + lambda: { + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + } + ) + + for row in c: + cname = self.normalize_name(row["cname"]) + fk = fks[cname] + if not fk["name"]: + fk["name"] = cname + fk["referred_table"] = self.normalize_name(row["targetrname"]) + fk["constrained_columns"].append(self.normalize_name(row["fname"])) + fk["referred_columns"].append( + self.normalize_name(row["targetfname"]) + ) + return list(fks.values()) + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + qry = """ + SELECT ix.rdb$index_name AS index_name, + ix.rdb$unique_flag AS unique_flag, + ic.rdb$field_name AS field_name + FROM rdb$indices ix + JOIN rdb$index_segments ic + ON ix.rdb$index_name=ic.rdb$index_name + LEFT OUTER JOIN rdb$relation_constraints + ON rdb$relation_constraints.rdb$index_name = + ic.rdb$index_name + WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL + AND rdb$relation_constraints.rdb$constraint_type IS NULL + ORDER BY index_name, ic.rdb$field_position + """ + c = connection.exec_driver_sql( + qry, [self.denormalize_name(table_name)] + ) + + indexes = util.defaultdict(dict) + for row in c: + indexrec = indexes[row["index_name"]] + if "name" not in indexrec: + indexrec["name"] = self.normalize_name(row["index_name"]) + indexrec["column_names"] = [] + indexrec["unique"] = bool(row["unique_flag"]) + + indexrec["column_names"].append( + self.normalize_name(row["field_name"]) + ) + + return list(indexes.values()) diff --git a/lib/sqlalchemy/dialects/firebird/fdb.py b/lib/sqlalchemy/dialects/firebird/fdb.py new file mode 100644 index 0000000..38f4432 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/fdb.py @@ -0,0 +1,112 @@ +# firebird/fdb.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: firebird+fdb + :name: fdb + :dbapi: pyodbc + :connectstring: firebird+fdb://user:password@host:port/path/to/db[?key=value&key=value...] + :url: https://pypi.org/project/fdb/ + + fdb is a kinterbasdb compatible DBAPI for Firebird. + + .. versionchanged:: 0.9 - The fdb dialect is now the default dialect + under the ``firebird://`` URL space, as ``fdb`` is now the official + Python driver for Firebird. + +Arguments +---------- + +The ``fdb`` dialect is based on the +:mod:`sqlalchemy.dialects.firebird.kinterbasdb` dialect, however does not +accept every argument that Kinterbasdb does. + +* ``enable_rowcount`` - True by default, setting this to False disables + the usage of "cursor.rowcount" with the + Kinterbasdb dialect, which SQLAlchemy ordinarily calls upon automatically + after any UPDATE or DELETE statement. When disabled, SQLAlchemy's + CursorResult will return -1 for result.rowcount. The rationale here is + that Kinterbasdb requires a second round trip to the database when + .rowcount is called - since SQLA's resultproxy automatically closes + the cursor after a non-result-returning statement, rowcount must be + called, if at all, before the result object is returned. Additionally, + cursor.rowcount may not return correct results with older versions + of Firebird, and setting this flag to False will also cause the + SQLAlchemy ORM to ignore its usage. The behavior can also be controlled on a + per-execution basis using the ``enable_rowcount`` option with + :meth:`_engine.Connection.execution_options`:: + + conn = engine.connect().execution_options(enable_rowcount=True) + r = conn.execute(stmt) + print(r.rowcount) + +* ``retaining`` - False by default. Setting this to True will pass the + ``retaining=True`` keyword argument to the ``.commit()`` and ``.rollback()`` + methods of the DBAPI connection, which can improve performance in some + situations, but apparently with significant caveats. + Please read the fdb and/or kinterbasdb DBAPI documentation in order to + understand the implications of this flag. + + .. versionchanged:: 0.9.0 - the ``retaining`` flag defaults to ``False``. + In 0.8 it defaulted to ``True``. + + .. seealso:: + + https://pythonhosted.org/fdb/usage-guide.html#retaining-transactions + - information on the "retaining" flag. + +""" # noqa + +from .kinterbasdb import FBDialect_kinterbasdb +from ... import util + + +class FBDialect_fdb(FBDialect_kinterbasdb): + supports_statement_cache = True + + def __init__(self, enable_rowcount=True, retaining=False, **kwargs): + super(FBDialect_fdb, self).__init__( + enable_rowcount=enable_rowcount, retaining=retaining, **kwargs + ) + + @classmethod + def dbapi(cls): + return __import__("fdb") + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + if opts.get("port"): + opts["host"] = "%s/%s" % (opts["host"], opts["port"]) + del opts["port"] + opts.update(url.query) + + util.coerce_kw_type(opts, "type_conv", int) + + return ([], opts) + + def _get_server_version_info(self, connection): + """Get the version of the Firebird server used by a connection. + + Returns a tuple of (`major`, `minor`, `build`), three integers + representing the version of the attached server. + """ + + # This is the simpler approach (the other uses the services api), + # that for backward compatibility reasons returns a string like + # LI-V6.3.3.12981 Firebird 2.0 + # where the first version is a fake one resembling the old + # Interbase signature. + + isc_info_firebird_version = 103 + fbconn = connection.connection + + version = fbconn.db_info(isc_info_firebird_version) + + return self._parse_version_info(version) + + +dialect = FBDialect_fdb diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py new file mode 100644 index 0000000..b999404 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py @@ -0,0 +1,202 @@ +# firebird/kinterbasdb.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: firebird+kinterbasdb + :name: kinterbasdb + :dbapi: kinterbasdb + :connectstring: firebird+kinterbasdb://user:password@host:port/path/to/db[?key=value&key=value...] + :url: https://firebirdsql.org/index.php?op=devel&sub=python + +Arguments +---------- + +The Kinterbasdb backend accepts the ``enable_rowcount`` and ``retaining`` +arguments accepted by the :mod:`sqlalchemy.dialects.firebird.fdb` dialect. +In addition, it also accepts the following: + +* ``type_conv`` - select the kind of mapping done on the types: by default + SQLAlchemy uses 200 with Unicode, datetime and decimal support. See + the linked documents below for further information. + +* ``concurrency_level`` - set the backend policy with regards to threading + issues: by default SQLAlchemy uses policy 1. See the linked documents + below for further information. + +.. seealso:: + + https://sourceforge.net/projects/kinterbasdb + + https://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation + + https://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency + +""" # noqa + +import decimal +from re import match + +from .base import FBDialect +from .base import FBExecutionContext +from ... import types as sqltypes +from ... import util + + +class _kinterbasdb_numeric(object): + def bind_processor(self, dialect): + def process(value): + if isinstance(value, decimal.Decimal): + return str(value) + else: + return value + + return process + + +class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric): + pass + + +class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float): + pass + + +class FBExecutionContext_kinterbasdb(FBExecutionContext): + @property + def rowcount(self): + if self.execution_options.get( + "enable_rowcount", self.dialect.enable_rowcount + ): + return self.cursor.rowcount + else: + return -1 + + +class FBDialect_kinterbasdb(FBDialect): + driver = "kinterbasdb" + supports_statement_cache = True + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + execution_ctx_cls = FBExecutionContext_kinterbasdb + + supports_native_decimal = True + + colspecs = util.update_copy( + FBDialect.colspecs, + { + sqltypes.Numeric: _FBNumeric_kinterbasdb, + sqltypes.Float: _FBFloat_kinterbasdb, + }, + ) + + def __init__( + self, + type_conv=200, + concurrency_level=1, + enable_rowcount=True, + retaining=False, + **kwargs + ): + super(FBDialect_kinterbasdb, self).__init__(**kwargs) + self.enable_rowcount = enable_rowcount + self.type_conv = type_conv + self.concurrency_level = concurrency_level + self.retaining = retaining + if enable_rowcount: + self.supports_sane_rowcount = True + + @classmethod + def dbapi(cls): + return __import__("kinterbasdb") + + def do_execute(self, cursor, statement, parameters, context=None): + # kinterbase does not accept a None, but wants an empty list + # when there are no arguments. + cursor.execute(statement, parameters or []) + + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback(self.retaining) + + def do_commit(self, dbapi_connection): + dbapi_connection.commit(self.retaining) + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + if opts.get("port"): + opts["host"] = "%s/%s" % (opts["host"], opts["port"]) + del opts["port"] + opts.update(url.query) + + util.coerce_kw_type(opts, "type_conv", int) + + type_conv = opts.pop("type_conv", self.type_conv) + concurrency_level = opts.pop( + "concurrency_level", self.concurrency_level + ) + + if self.dbapi is not None: + initialized = getattr(self.dbapi, "initialized", None) + if initialized is None: + # CVS rev 1.96 changed the name of the attribute: + # https://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/ + # Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96 + initialized = getattr(self.dbapi, "_initialized", False) + if not initialized: + self.dbapi.init( + type_conv=type_conv, concurrency_level=concurrency_level + ) + return ([], opts) + + def _get_server_version_info(self, connection): + """Get the version of the Firebird server used by a connection. + + Returns a tuple of (`major`, `minor`, `build`), three integers + representing the version of the attached server. + """ + + # This is the simpler approach (the other uses the services api), + # that for backward compatibility reasons returns a string like + # LI-V6.3.3.12981 Firebird 2.0 + # where the first version is a fake one resembling the old + # Interbase signature. + + fbconn = connection.connection + version = fbconn.server_version + + return self._parse_version_info(version) + + def _parse_version_info(self, version): + m = match( + r"\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?", version + ) + if not m: + raise AssertionError( + "Could not determine version from string '%s'" % version + ) + + if m.group(5) != None: + return tuple([int(x) for x in m.group(6, 7, 4)] + ["firebird"]) + else: + return tuple([int(x) for x in m.group(1, 2, 3)] + ["interbase"]) + + def is_disconnect(self, e, connection, cursor): + if isinstance( + e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError) + ): + msg = str(e) + return ( + "Error writing data to the connection" in msg + or "Unable to complete network request to host" in msg + or "Invalid connection state" in msg + or "Invalid cursor state" in msg + or "connection shutdown" in msg + ) + else: + return False + + +dialect = FBDialect_kinterbasdb diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py new file mode 100644 index 0000000..cae0168 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -0,0 +1,85 @@ +# mssql/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from . import base # noqa +from . import mxodbc # noqa +from . import pymssql # noqa +from . import pyodbc # noqa +from .base import BIGINT +from .base import BINARY +from .base import BIT +from .base import CHAR +from .base import DATE +from .base import DATETIME +from .base import DATETIME2 +from .base import DATETIMEOFFSET +from .base import DECIMAL +from .base import FLOAT +from .base import IMAGE +from .base import INTEGER +from .base import JSON +from .base import MONEY +from .base import NCHAR +from .base import NTEXT +from .base import NUMERIC +from .base import NVARCHAR +from .base import REAL +from .base import ROWVERSION +from .base import SMALLDATETIME +from .base import SMALLINT +from .base import SMALLMONEY +from .base import SQL_VARIANT +from .base import TEXT +from .base import TIME +from .base import TIMESTAMP +from .base import TINYINT +from .base import try_cast +from .base import UNIQUEIDENTIFIER +from .base import VARBINARY +from .base import VARCHAR +from .base import XML + + +base.dialect = dialect = pyodbc.dialect + + +__all__ = ( + "JSON", + "INTEGER", + "BIGINT", + "SMALLINT", + "TINYINT", + "VARCHAR", + "NVARCHAR", + "CHAR", + "NCHAR", + "TEXT", + "NTEXT", + "DECIMAL", + "NUMERIC", + "FLOAT", + "DATETIME", + "DATETIME2", + "DATETIMEOFFSET", + "DATE", + "TIME", + "SMALLDATETIME", + "BINARY", + "VARBINARY", + "BIT", + "REAL", + "IMAGE", + "TIMESTAMP", + "ROWVERSION", + "MONEY", + "SMALLMONEY", + "UNIQUEIDENTIFIER", + "SQL_VARIANT", + "XML", + "dialect", + "try_cast", +) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py new file mode 100644 index 0000000..ee6ce87 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -0,0 +1,3545 @@ +# mssql/base.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +""" +.. dialect:: mssql + :name: Microsoft SQL Server + :full_support: 2017 + :normal_support: 2012+ + :best_effort: 2005+ + +.. _mssql_external_dialects: + +External Dialects +----------------- + +In addition to the above DBAPI layers with native SQLAlchemy support, there +are third-party dialects for other DBAPI layers that are compatible +with SQL Server. See the "External Dialects" list on the +:ref:`dialect_toplevel` page. + +.. _mssql_identity: + +Auto Increment Behavior / IDENTITY Columns +------------------------------------------ + +SQL Server provides so-called "auto incrementing" behavior using the +``IDENTITY`` construct, which can be placed on any single integer column in a +table. SQLAlchemy considers ``IDENTITY`` within its default "autoincrement" +behavior for an integer primary key column, described at +:paramref:`_schema.Column.autoincrement`. This means that by default, +the first integer primary key column in a :class:`_schema.Table` will be +considered to be the identity column - unless it is associated with a +:class:`.Sequence` - and will generate DDL as such:: + + from sqlalchemy import Table, MetaData, Column, Integer + + m = MetaData() + t = Table('t', m, + Column('id', Integer, primary_key=True), + Column('x', Integer)) + m.create_all(engine) + +The above example will generate DDL as: + +.. sourcecode:: sql + + CREATE TABLE t ( + id INTEGER NOT NULL IDENTITY, + x INTEGER NULL, + PRIMARY KEY (id) + ) + +For the case where this default generation of ``IDENTITY`` is not desired, +specify ``False`` for the :paramref:`_schema.Column.autoincrement` flag, +on the first integer primary key column:: + + m = MetaData() + t = Table('t', m, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('x', Integer)) + m.create_all(engine) + +To add the ``IDENTITY`` keyword to a non-primary key column, specify +``True`` for the :paramref:`_schema.Column.autoincrement` flag on the desired +:class:`_schema.Column` object, and ensure that +:paramref:`_schema.Column.autoincrement` +is set to ``False`` on any integer primary key column:: + + m = MetaData() + t = Table('t', m, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('x', Integer, autoincrement=True)) + m.create_all(engine) + +.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct + in a :class:`_schema.Column` to specify the start and increment + parameters of an IDENTITY. These replace + the use of the :class:`.Sequence` object in order to specify these values. + +.. deprecated:: 1.4 + + The ``mssql_identity_start`` and ``mssql_identity_increment`` parameters + to :class:`_schema.Column` are deprecated and should we replaced by + an :class:`_schema.Identity` object. Specifying both ways of configuring + an IDENTITY will result in a compile error. + These options are also no longer returned as part of the + ``dialect_options`` key in :meth:`_reflection.Inspector.get_columns`. + Use the information in the ``identity`` key instead. + +.. deprecated:: 1.3 + + The use of :class:`.Sequence` to specify IDENTITY characteristics is + deprecated and will be removed in a future release. Please use + the :class:`_schema.Identity` object parameters + :paramref:`_schema.Identity.start` and + :paramref:`_schema.Identity.increment`. + +.. versionchanged:: 1.4 Removed the ability to use a :class:`.Sequence` + object to modify IDENTITY characteristics. :class:`.Sequence` objects + now only manipulate true T-SQL SEQUENCE types. + +.. note:: + + There can only be one IDENTITY column on the table. When using + ``autoincrement=True`` to enable the IDENTITY keyword, SQLAlchemy does not + guard against multiple columns specifying the option simultaneously. The + SQL Server database will instead reject the ``CREATE TABLE`` statement. + +.. note:: + + An INSERT statement which attempts to provide a value for a column that is + marked with IDENTITY will be rejected by SQL Server. In order for the + value to be accepted, a session-level option "SET IDENTITY_INSERT" must be + enabled. The SQLAlchemy SQL Server dialect will perform this operation + automatically when using a core :class:`_expression.Insert` + construct; if the + execution specifies a value for the IDENTITY column, the "IDENTITY_INSERT" + option will be enabled for the span of that statement's invocation.However, + this scenario is not high performing and should not be relied upon for + normal use. If a table doesn't actually require IDENTITY behavior in its + integer primary key column, the keyword should be disabled when creating + the table by ensuring that ``autoincrement=False`` is set. + +Controlling "Start" and "Increment" +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Specific control over the "start" and "increment" values for +the ``IDENTITY`` generator are provided using the +:paramref:`_schema.Identity.start` and :paramref:`_schema.Identity.increment` +parameters passed to the :class:`_schema.Identity` object:: + + from sqlalchemy import Table, Integer, Column, Identity + + test = Table( + 'test', metadata, + Column( + 'id', + Integer, + primary_key=True, + Identity(start=100, increment=10) + ), + Column('name', String(20)) + ) + +The CREATE TABLE for the above :class:`_schema.Table` object would be: + +.. sourcecode:: sql + + CREATE TABLE test ( + id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, + name VARCHAR(20) NULL, + ) + +.. note:: + + The :class:`_schema.Identity` object supports many other parameter in + addition to ``start`` and ``increment``. These are not supported by + SQL Server and will be ignored when generating the CREATE TABLE ddl. + +.. versionchanged:: 1.3.19 The :class:`_schema.Identity` object is + now used to affect the + ``IDENTITY`` generator for a :class:`_schema.Column` under SQL Server. + Previously, the :class:`.Sequence` object was used. As SQL Server now + supports real sequences as a separate construct, :class:`.Sequence` will be + functional in the normal way starting from SQLAlchemy version 1.4. + + +Using IDENTITY with Non-Integer numeric types +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +SQL Server also allows ``IDENTITY`` to be used with ``NUMERIC`` columns. To +implement this pattern smoothly in SQLAlchemy, the primary datatype of the +column should remain as ``Integer``, however the underlying implementation +type deployed to the SQL Server database can be specified as ``Numeric`` using +:meth:`.TypeEngine.with_variant`:: + + from sqlalchemy import Column + from sqlalchemy import Integer + from sqlalchemy import Numeric + from sqlalchemy import String + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class TestTable(Base): + __tablename__ = "test" + id = Column( + Integer().with_variant(Numeric(10, 0), "mssql"), + primary_key=True, + autoincrement=True, + ) + name = Column(String) + +In the above example, ``Integer().with_variant()`` provides clear usage +information that accurately describes the intent of the code. The general +restriction that ``autoincrement`` only applies to ``Integer`` is established +at the metadata level and not at the per-dialect level. + +When using the above pattern, the primary key identifier that comes back from +the insertion of a row, which is also the value that would be assigned to an +ORM object such as ``TestTable`` above, will be an instance of ``Decimal()`` +and not ``int`` when using SQL Server. The numeric return type of the +:class:`_types.Numeric` type can be changed to return floats by passing False +to :paramref:`_types.Numeric.asdecimal`. To normalize the return type of the +above ``Numeric(10, 0)`` to return Python ints (which also support "long" +integer values in Python 3), use :class:`_types.TypeDecorator` as follows:: + + from sqlalchemy import TypeDecorator + + class NumericAsInteger(TypeDecorator): + '''normalize floating point return values into ints''' + + impl = Numeric(10, 0, asdecimal=False) + cache_ok = True + + def process_result_value(self, value, dialect): + if value is not None: + value = int(value) + return value + + class TestTable(Base): + __tablename__ = "test" + id = Column( + Integer().with_variant(NumericAsInteger, "mssql"), + primary_key=True, + autoincrement=True, + ) + name = Column(String) + + +INSERT behavior +^^^^^^^^^^^^^^^^ + +Handling of the ``IDENTITY`` column at INSERT time involves two key +techniques. The most common is being able to fetch the "last inserted value" +for a given ``IDENTITY`` column, a process which SQLAlchemy performs +implicitly in many cases, most importantly within the ORM. + +The process for fetching this value has several variants: + +* In the vast majority of cases, RETURNING is used in conjunction with INSERT + statements on SQL Server in order to get newly generated primary key values: + + .. sourcecode:: sql + + INSERT INTO t (x) OUTPUT inserted.id VALUES (?) + +* When RETURNING is not available or has been disabled via + ``implicit_returning=False``, either the ``scope_identity()`` function or + the ``@@identity`` variable is used; behavior varies by backend: + + * when using PyODBC, the phrase ``; select scope_identity()`` will be + appended to the end of the INSERT statement; a second result set will be + fetched in order to receive the value. Given a table as:: + + t = Table('t', m, Column('id', Integer, primary_key=True), + Column('x', Integer), + implicit_returning=False) + + an INSERT will look like: + + .. sourcecode:: sql + + INSERT INTO t (x) VALUES (?); select scope_identity() + + * Other dialects such as pymssql will call upon + ``SELECT scope_identity() AS lastrowid`` subsequent to an INSERT + statement. If the flag ``use_scope_identity=False`` is passed to + :func:`_sa.create_engine`, + the statement ``SELECT @@identity AS lastrowid`` + is used instead. + +A table that contains an ``IDENTITY`` column will prohibit an INSERT statement +that refers to the identity column explicitly. The SQLAlchemy dialect will +detect when an INSERT construct, created using a core +:func:`_expression.insert` +construct (not a plain string SQL), refers to the identity column, and +in this case will emit ``SET IDENTITY_INSERT ON`` prior to the insert +statement proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the +execution. Given this example:: + + m = MetaData() + t = Table('t', m, Column('id', Integer, primary_key=True), + Column('x', Integer)) + m.create_all(engine) + + with engine.begin() as conn: + conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2}) + +The above column will be created with IDENTITY, however the INSERT statement +we emit is specifying explicit values. In the echo output we can see +how SQLAlchemy handles this: + +.. sourcecode:: sql + + CREATE TABLE t ( + id INTEGER NOT NULL IDENTITY(1,1), + x INTEGER NULL, + PRIMARY KEY (id) + ) + + COMMIT + SET IDENTITY_INSERT t ON + INSERT INTO t (id, x) VALUES (?, ?) + ((1, 1), (2, 2)) + SET IDENTITY_INSERT t OFF + COMMIT + + + +This is an auxiliary use case suitable for testing and bulk insert scenarios. + +SEQUENCE support +---------------- + +The :class:`.Sequence` object now creates "real" sequences, i.e., +``CREATE SEQUENCE``. To provide compatibility with other dialects, +:class:`.Sequence` defaults to a start value of 1, even though the +T-SQL defaults is -9223372036854775808. + +.. versionadded:: 1.4.0 + +MAX on VARCHAR / NVARCHAR +------------------------- + +SQL Server supports the special string "MAX" within the +:class:`_types.VARCHAR` and :class:`_types.NVARCHAR` datatypes, +to indicate "maximum length possible". The dialect currently handles this as +a length of "None" in the base type, rather than supplying a +dialect-specific version of these types, so that a base type +specified such as ``VARCHAR(None)`` can assume "unlengthed" behavior on +more than one backend without using dialect-specific types. + +To build a SQL Server VARCHAR or NVARCHAR with MAX length, use None:: + + my_table = Table( + 'my_table', metadata, + Column('my_data', VARCHAR(None)), + Column('my_n_data', NVARCHAR(None)) + ) + + +Collation Support +----------------- + +Character collations are supported by the base string types, +specified by the string argument "collation":: + + from sqlalchemy import VARCHAR + Column('login', VARCHAR(32, collation='Latin1_General_CI_AS')) + +When such a column is associated with a :class:`_schema.Table`, the +CREATE TABLE statement for this column will yield:: + + login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL + +LIMIT/OFFSET Support +-------------------- + +MSSQL has added support for LIMIT / OFFSET as of SQL Server 2012, via the +"OFFSET n ROWS" and "FETCH NEXT n ROWS" clauses. SQLAlchemy supports these +syntaxes automatically if SQL Server 2012 or greater is detected. + +.. versionchanged:: 1.4 support added for SQL Server "OFFSET n ROWS" and + "FETCH NEXT n ROWS" syntax. + +For statements that specify only LIMIT and no OFFSET, all versions of SQL +Server support the TOP keyword. This syntax is used for all SQL Server +versions when no OFFSET clause is present. A statement such as:: + + select(some_table).limit(5) + +will render similarly to:: + + SELECT TOP 5 col1, col2.. FROM table + +For versions of SQL Server prior to SQL Server 2012, a statement that uses +LIMIT and OFFSET, or just OFFSET alone, will be rendered using the +``ROW_NUMBER()`` window function. A statement such as:: + + select(some_table).order_by(some_table.c.col3).limit(5).offset(10) + +will render similarly to:: + + SELECT anon_1.col1, anon_1.col2 FROM (SELECT col1, col2, + ROW_NUMBER() OVER (ORDER BY col3) AS + mssql_rn FROM table WHERE t.x = :x_1) AS + anon_1 WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1 + +Note that when using LIMIT and/or OFFSET, whether using the older +or newer SQL Server syntaxes, the statement must have an ORDER BY as well, +else a :class:`.CompileError` is raised. + +.. _mssql_isolation_level: + +Transaction Isolation Level +--------------------------- + +All SQL Server dialects support setting of transaction isolation level +both via a dialect-specific parameter +:paramref:`_sa.create_engine.isolation_level` +accepted by :func:`_sa.create_engine`, +as well as the :paramref:`.Connection.execution_options.isolation_level` +argument as passed to +:meth:`_engine.Connection.execution_options`. +This feature works by issuing the +command ``SET TRANSACTION ISOLATION LEVEL <level>`` for +each new connection. + +To set isolation level using :func:`_sa.create_engine`:: + + engine = create_engine( + "mssql+pyodbc://scott:tiger@ms_2008", + isolation_level="REPEATABLE READ" + ) + +To set using per-connection execution options:: + + connection = engine.connect() + connection = connection.execution_options( + isolation_level="READ COMMITTED" + ) + +Valid values for ``isolation_level`` include: + +* ``AUTOCOMMIT`` - pyodbc / pymssql-specific +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``SNAPSHOT`` - specific to SQL Server + +There are also more options for isolation level configurations, such as +"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply +different isolation level settings. See the discussion at +:ref:`dbapi_autocommit` for background. + +.. seealso:: + + :ref:`dbapi_autocommit` + +Nullability +----------- +MSSQL has support for three levels of column nullability. The default +nullability allows nulls and is explicit in the CREATE TABLE +construct:: + + name VARCHAR(20) NULL + +If ``nullable=None`` is specified then no specification is made. In +other words the database's configured default is used. This will +render:: + + name VARCHAR(20) + +If ``nullable`` is ``True`` or ``False`` then the column will be +``NULL`` or ``NOT NULL`` respectively. + +Date / Time Handling +-------------------- +DATE and TIME are supported. Bind parameters are converted +to datetime.datetime() objects as required by most MSSQL drivers, +and results are processed from strings if needed. +The DATE and TIME types are not available for MSSQL 2005 and +previous - if a server version below 2008 is detected, DDL +for these types will be issued as DATETIME. + +.. _mssql_large_type_deprecation: + +Large Text/Binary Type Deprecation +---------------------------------- + +Per +`SQL Server 2012/2014 Documentation <https://technet.microsoft.com/en-us/library/ms187993.aspx>`_, +the ``NTEXT``, ``TEXT`` and ``IMAGE`` datatypes are to be removed from SQL +Server in a future release. SQLAlchemy normally relates these types to the +:class:`.UnicodeText`, :class:`_expression.TextClause` and +:class:`.LargeBinary` datatypes. + +In order to accommodate this change, a new flag ``deprecate_large_types`` +is added to the dialect, which will be automatically set based on detection +of the server version in use, if not otherwise set by the user. The +behavior of this flag is as follows: + +* When this flag is ``True``, the :class:`.UnicodeText`, + :class:`_expression.TextClause` and + :class:`.LargeBinary` datatypes, when used to render DDL, will render the + types ``NVARCHAR(max)``, ``VARCHAR(max)``, and ``VARBINARY(max)``, + respectively. This is a new behavior as of the addition of this flag. + +* When this flag is ``False``, the :class:`.UnicodeText`, + :class:`_expression.TextClause` and + :class:`.LargeBinary` datatypes, when used to render DDL, will render the + types ``NTEXT``, ``TEXT``, and ``IMAGE``, + respectively. This is the long-standing behavior of these types. + +* The flag begins with the value ``None``, before a database connection is + established. If the dialect is used to render DDL without the flag being + set, it is interpreted the same as ``False``. + +* On first connection, the dialect detects if SQL Server version 2012 or + greater is in use; if the flag is still at ``None``, it sets it to ``True`` + or ``False`` based on whether 2012 or greater is detected. + +* The flag can be set to either ``True`` or ``False`` when the dialect + is created, typically via :func:`_sa.create_engine`:: + + eng = create_engine("mssql+pymssql://user:pass@host/db", + deprecate_large_types=True) + +* Complete control over whether the "old" or "new" types are rendered is + available in all SQLAlchemy versions by using the UPPERCASE type objects + instead: :class:`_types.NVARCHAR`, :class:`_types.VARCHAR`, + :class:`_types.VARBINARY`, :class:`_types.TEXT`, :class:`_mssql.NTEXT`, + :class:`_mssql.IMAGE` + will always remain fixed and always output exactly that + type. + +.. versionadded:: 1.0.0 + +.. _multipart_schema_names: + +Multipart Schema Names +---------------------- + +SQL Server schemas sometimes require multiple parts to their "schema" +qualifier, that is, including the database name and owner name as separate +tokens, such as ``mydatabase.dbo.some_table``. These multipart names can be set +at once using the :paramref:`_schema.Table.schema` argument of +:class:`_schema.Table`:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="mydatabase.dbo" + ) + +When performing operations such as table or component reflection, a schema +argument that contains a dot will be split into separate +"database" and "owner" components in order to correctly query the SQL +Server information schema tables, as these two values are stored separately. +Additionally, when rendering the schema name for DDL or SQL, the two +components will be quoted separately for case sensitive names and other +special characters. Given an argument as below:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="MyDataBase.dbo" + ) + +The above schema would be rendered as ``[MyDataBase].dbo``, and also in +reflection, would be reflected using "dbo" as the owner and "MyDataBase" +as the database name. + +To control how the schema name is broken into database / owner, +specify brackets (which in SQL Server are quoting characters) in the name. +Below, the "owner" will be considered as ``MyDataBase.dbo`` and the +"database" will be None:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="[MyDataBase.dbo]" + ) + +To individually specify both database and owner name with special characters +or embedded dots, use two sets of brackets:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="[MyDataBase.Period].[MyOwner.Dot]" + ) + + +.. versionchanged:: 1.2 the SQL Server dialect now treats brackets as + identifier delimiters splitting the schema into separate database + and owner tokens, to allow dots within either name itself. + +.. _legacy_schema_rendering: + +Legacy Schema Mode +------------------ + +Very old versions of the MSSQL dialect introduced the behavior such that a +schema-qualified table would be auto-aliased when used in a +SELECT statement; given a table:: + + account_table = Table( + 'account', metadata, + Column('id', Integer, primary_key=True), + Column('info', String(100)), + schema="customer_schema" + ) + +this legacy mode of rendering would assume that "customer_schema.account" +would not be accepted by all parts of the SQL statement, as illustrated +below:: + + >>> eng = create_engine("mssql+pymssql://mydsn", legacy_schema_aliasing=True) + >>> print(account_table.select().compile(eng)) + SELECT account_1.id, account_1.info + FROM customer_schema.account AS account_1 + +This mode of behavior is now off by default, as it appears to have served +no purpose; however in the case that legacy applications rely upon it, +it is available using the ``legacy_schema_aliasing`` argument to +:func:`_sa.create_engine` as illustrated above. + +.. versionchanged:: 1.1 the ``legacy_schema_aliasing`` flag introduced + in version 1.0.5 to allow disabling of legacy mode for schemas now + defaults to False. + +.. deprecated:: 1.4 + + The ``legacy_schema_aliasing`` flag is now + deprecated and will be removed in a future release. + +.. _mssql_indexes: + +Clustered Index Support +----------------------- + +The MSSQL dialect supports clustered indexes (and primary keys) via the +``mssql_clustered`` option. This option is available to :class:`.Index`, +:class:`.UniqueConstraint`. and :class:`.PrimaryKeyConstraint`. + +To generate a clustered index:: + + Index("my_index", table.c.x, mssql_clustered=True) + +which renders the index as ``CREATE CLUSTERED INDEX my_index ON table (x)``. + +To generate a clustered primary key use:: + + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x", "y", mssql_clustered=True)) + +which will render the table, for example, as:: + + CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, + PRIMARY KEY CLUSTERED (x, y)) + +Similarly, we can generate a clustered unique constraint using:: + + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x"), + UniqueConstraint("y", mssql_clustered=True), + ) + +To explicitly request a non-clustered primary key (for example, when +a separate clustered index is desired), use:: + + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x", "y", mssql_clustered=False)) + +which will render the table, for example, as:: + + CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, + PRIMARY KEY NONCLUSTERED (x, y)) + +.. versionchanged:: 1.1 the ``mssql_clustered`` option now defaults + to None, rather than False. ``mssql_clustered=False`` now explicitly + renders the NONCLUSTERED clause, whereas None omits the CLUSTERED + clause entirely, allowing SQL Server defaults to take effect. + + +MSSQL-Specific Index Options +----------------------------- + +In addition to clustering, the MSSQL dialect supports other special options +for :class:`.Index`. + +INCLUDE +^^^^^^^ + +The ``mssql_include`` option renders INCLUDE(colname) for the given string +names:: + + Index("my_index", table.c.x, mssql_include=['y']) + +would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)`` + +.. _mssql_index_where: + +Filtered Indexes +^^^^^^^^^^^^^^^^ + +The ``mssql_where`` option renders WHERE(condition) for the given string +names:: + + Index("my_index", table.c.x, mssql_where=table.c.x > 10) + +would render the index as ``CREATE INDEX my_index ON table (x) WHERE x > 10``. + +.. versionadded:: 1.3.4 + +Index ordering +^^^^^^^^^^^^^^ + +Index ordering is available via functional expressions, such as:: + + Index("my_index", table.c.x.desc()) + +would render the index as ``CREATE INDEX my_index ON table (x DESC)`` + +.. seealso:: + + :ref:`schema_indexes_functional` + +Compatibility Levels +-------------------- +MSSQL supports the notion of setting compatibility levels at the +database level. This allows, for instance, to run a database that +is compatible with SQL2000 while running on a SQL2005 database +server. ``server_version_info`` will always return the database +server version information (in this case SQL2005) and not the +compatibility level information. Because of this, if running under +a backwards compatibility mode SQLAlchemy may attempt to use T-SQL +statements that are unable to be parsed by the database server. + +Triggers +-------- + +SQLAlchemy by default uses OUTPUT INSERTED to get at newly +generated primary key values via IDENTITY columns or other +server side defaults. MS-SQL does not +allow the usage of OUTPUT INSERTED on tables that have triggers. +To disable the usage of OUTPUT INSERTED on a per-table basis, +specify ``implicit_returning=False`` for each :class:`_schema.Table` +which has triggers:: + + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + # ..., + implicit_returning=False + ) + +Declarative form:: + + class MyClass(Base): + # ... + __table_args__ = {'implicit_returning':False} + + +This option can also be specified engine-wide using the +``implicit_returning=False`` argument on :func:`_sa.create_engine`. + +.. _mssql_rowcount_versioning: + +Rowcount Support / ORM Versioning +--------------------------------- + +The SQL Server drivers may have limited ability to return the number +of rows updated from an UPDATE or DELETE statement. + +As of this writing, the PyODBC driver is not able to return a rowcount when +OUTPUT INSERTED is used. This impacts the SQLAlchemy ORM's versioning feature +in many cases where server-side value generators are in use in that while the +versioning operations can succeed, the ORM cannot always check that an UPDATE +or DELETE statement matched the number of rows expected, which is how it +verifies that the version identifier matched. When this condition occurs, a +warning will be emitted but the operation will proceed. + +The use of OUTPUT INSERTED can be disabled by setting the +:paramref:`_schema.Table.implicit_returning` flag to ``False`` on a particular +:class:`_schema.Table`, which in declarative looks like:: + + class MyTable(Base): + __tablename__ = 'mytable' + id = Column(Integer, primary_key=True) + stuff = Column(String(10)) + timestamp = Column(TIMESTAMP(), default=text('DEFAULT')) + __mapper_args__ = { + 'version_id_col': timestamp, + 'version_id_generator': False, + } + __table_args__ = { + 'implicit_returning': False + } + +Enabling Snapshot Isolation +--------------------------- + +SQL Server has a default transaction +isolation mode that locks entire tables, and causes even mildly concurrent +applications to have long held locks and frequent deadlocks. +Enabling snapshot isolation for the database as a whole is recommended +for modern levels of concurrency support. This is accomplished via the +following ALTER DATABASE commands executed at the SQL prompt:: + + ALTER DATABASE MyDatabase SET ALLOW_SNAPSHOT_ISOLATION ON + + ALTER DATABASE MyDatabase SET READ_COMMITTED_SNAPSHOT ON + +Background on SQL Server snapshot isolation is available at +https://msdn.microsoft.com/en-us/library/ms175095.aspx. + +""" # noqa + +import codecs +import datetime +import operator +import re + +from . import information_schema as ischema +from .json import JSON +from .json import JSONIndexType +from .json import JSONPathType +from ... import exc +from ... import Identity +from ... import schema as sa_schema +from ... import Sequence +from ... import sql +from ... import text +from ... import types as sqltypes +from ... import util +from ...engine import cursor as _cursor +from ...engine import default +from ...engine import reflection +from ...sql import coercions +from ...sql import compiler +from ...sql import elements +from ...sql import expression +from ...sql import func +from ...sql import quoted_name +from ...sql import roles +from ...sql import util as sql_util +from ...types import BIGINT +from ...types import BINARY +from ...types import CHAR +from ...types import DATE +from ...types import DATETIME +from ...types import DECIMAL +from ...types import FLOAT +from ...types import INTEGER +from ...types import NCHAR +from ...types import NUMERIC +from ...types import NVARCHAR +from ...types import SMALLINT +from ...types import TEXT +from ...types import VARCHAR +from ...util import compat +from ...util import update_wrapper +from ...util.langhelpers import public_factory + + +# https://sqlserverbuilds.blogspot.com/ +MS_2017_VERSION = (14,) +MS_2016_VERSION = (13,) +MS_2014_VERSION = (12,) +MS_2012_VERSION = (11,) +MS_2008_VERSION = (10,) +MS_2005_VERSION = (9,) +MS_2000_VERSION = (8,) + +RESERVED_WORDS = set( + [ + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "authorization", + "backup", + "begin", + "between", + "break", + "browse", + "bulk", + "by", + "cascade", + "case", + "check", + "checkpoint", + "close", + "clustered", + "coalesce", + "collate", + "column", + "commit", + "compute", + "constraint", + "contains", + "containstable", + "continue", + "convert", + "create", + "cross", + "current", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "dbcc", + "deallocate", + "declare", + "default", + "delete", + "deny", + "desc", + "disk", + "distinct", + "distributed", + "double", + "drop", + "dump", + "else", + "end", + "errlvl", + "escape", + "except", + "exec", + "execute", + "exists", + "exit", + "external", + "fetch", + "file", + "fillfactor", + "for", + "foreign", + "freetext", + "freetexttable", + "from", + "full", + "function", + "goto", + "grant", + "group", + "having", + "holdlock", + "identity", + "identity_insert", + "identitycol", + "if", + "in", + "index", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "key", + "kill", + "left", + "like", + "lineno", + "load", + "merge", + "national", + "nocheck", + "nonclustered", + "not", + "null", + "nullif", + "of", + "off", + "offsets", + "on", + "open", + "opendatasource", + "openquery", + "openrowset", + "openxml", + "option", + "or", + "order", + "outer", + "over", + "percent", + "pivot", + "plan", + "precision", + "primary", + "print", + "proc", + "procedure", + "public", + "raiserror", + "read", + "readtext", + "reconfigure", + "references", + "replication", + "restore", + "restrict", + "return", + "revert", + "revoke", + "right", + "rollback", + "rowcount", + "rowguidcol", + "rule", + "save", + "schema", + "securityaudit", + "select", + "session_user", + "set", + "setuser", + "shutdown", + "some", + "statistics", + "system_user", + "table", + "tablesample", + "textsize", + "then", + "to", + "top", + "tran", + "transaction", + "trigger", + "truncate", + "tsequal", + "union", + "unique", + "unpivot", + "update", + "updatetext", + "use", + "user", + "values", + "varying", + "view", + "waitfor", + "when", + "where", + "while", + "with", + "writetext", + ] +) + + +class REAL(sqltypes.REAL): + __visit_name__ = "REAL" + + def __init__(self, **kw): + # REAL is a synonym for FLOAT(24) on SQL server. + # it is only accepted as the word "REAL" in DDL, the numeric + # precision value is not allowed to be present + kw.setdefault("precision", 24) + super(REAL, self).__init__(**kw) + + +class TINYINT(sqltypes.Integer): + __visit_name__ = "TINYINT" + + +# MSSQL DATE/TIME types have varied behavior, sometimes returning +# strings. MSDate/TIME check for everything, and always +# filter bind parameters into datetime objects (required by pyodbc, +# not sure about other dialects). + + +class _MSDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + + return process + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + + def result_processor(self, dialect, coltype): + def process(value): + if isinstance(value, datetime.datetime): + return value.date() + elif isinstance(value, util.string_types): + m = self._reg.match(value) + if not m: + raise ValueError( + "could not parse %r as a date value" % (value,) + ) + return datetime.date(*[int(x or 0) for x in m.groups()]) + else: + return value + + return process + + +class TIME(sqltypes.TIME): + def __init__(self, precision=None, **kwargs): + self.precision = precision + super(TIME, self).__init__() + + __zero_date = datetime.date(1900, 1, 1) + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine( + self.__zero_date, value.time() + ) + elif isinstance(value, datetime.time): + """issue #5339 + per: https://github.com/mkleehammer/pyodbc/wiki/Tips-and-Tricks-by-Database-Platform#time-columns + pass TIME value as string + """ # noqa + value = str(value) + return value + + return process + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?") + + def result_processor(self, dialect, coltype): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, util.string_types): + m = self._reg.match(value) + if not m: + raise ValueError( + "could not parse %r as a time value" % (value,) + ) + return datetime.time(*[int(x or 0) for x in m.groups()]) + else: + return value + + return process + + +_MSTime = TIME + + +class _BASETIMEIMPL(TIME): + __visit_name__ = "_BASETIMEIMPL" + + +class _DateTimeBase(object): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + + return process + + +class _MSDateTime(_DateTimeBase, sqltypes.DateTime): + pass + + +class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = "SMALLDATETIME" + + +class DATETIME2(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = "DATETIME2" + + def __init__(self, precision=None, **kw): + super(DATETIME2, self).__init__(**kw) + self.precision = precision + + +class DATETIMEOFFSET(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = "DATETIMEOFFSET" + + def __init__(self, precision=None, **kw): + super(DATETIMEOFFSET, self).__init__(**kw) + self.precision = precision + + +class _UnicodeLiteral(object): + def literal_processor(self, dialect): + def process(value): + + value = value.replace("'", "''") + + if dialect.identifier_preparer._double_percents: + value = value.replace("%", "%%") + + return "N'%s'" % value + + return process + + +class _MSUnicode(_UnicodeLiteral, sqltypes.Unicode): + pass + + +class _MSUnicodeText(_UnicodeLiteral, sqltypes.UnicodeText): + pass + + +class TIMESTAMP(sqltypes._Binary): + """Implement the SQL Server TIMESTAMP type. + + Note this is **completely different** than the SQL Standard + TIMESTAMP type, which is not supported by SQL Server. It + is a read-only datatype that does not support INSERT of values. + + .. versionadded:: 1.2 + + .. seealso:: + + :class:`_mssql.ROWVERSION` + + """ + + __visit_name__ = "TIMESTAMP" + + # expected by _Binary to be present + length = None + + def __init__(self, convert_int=False): + """Construct a TIMESTAMP or ROWVERSION type. + + :param convert_int: if True, binary integer values will + be converted to integers on read. + + .. versionadded:: 1.2 + + """ + self.convert_int = convert_int + + def result_processor(self, dialect, coltype): + super_ = super(TIMESTAMP, self).result_processor(dialect, coltype) + if self.convert_int: + + def process(value): + value = super_(value) + if value is not None: + # https://stackoverflow.com/a/30403242/34549 + value = int(codecs.encode(value, "hex"), 16) + return value + + return process + else: + return super_ + + +class ROWVERSION(TIMESTAMP): + """Implement the SQL Server ROWVERSION type. + + The ROWVERSION datatype is a SQL Server synonym for the TIMESTAMP + datatype, however current SQL Server documentation suggests using + ROWVERSION for new datatypes going forward. + + The ROWVERSION datatype does **not** reflect (e.g. introspect) from the + database as itself; the returned datatype will be + :class:`_mssql.TIMESTAMP`. + + This is a read-only datatype that does not support INSERT of values. + + .. versionadded:: 1.2 + + .. seealso:: + + :class:`_mssql.TIMESTAMP` + + """ + + __visit_name__ = "ROWVERSION" + + +class NTEXT(sqltypes.UnicodeText): + + """MSSQL NTEXT type, for variable-length unicode text up to 2^30 + characters.""" + + __visit_name__ = "NTEXT" + + +class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): + """The MSSQL VARBINARY type. + + This type adds additional features to the core :class:`_types.VARBINARY` + type, including "deprecate_large_types" mode where + either ``VARBINARY(max)`` or IMAGE is rendered, as well as the SQL + Server ``FILESTREAM`` option. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :ref:`mssql_large_type_deprecation` + + """ + + __visit_name__ = "VARBINARY" + + def __init__(self, length=None, filestream=False): + """ + Construct a VARBINARY type. + + :param length: optional, a length for the column for use in + DDL statements, for those binary types that accept a length, + such as the MySQL BLOB type. + + :param filestream=False: if True, renders the ``FILESTREAM`` keyword + in the table definition. In this case ``length`` must be ``None`` + or ``'max'``. + + .. versionadded:: 1.4.31 + + """ + + self.filestream = filestream + if self.filestream and length not in (None, "max"): + raise ValueError( + "length must be None or 'max' when setting filestream" + ) + super(VARBINARY, self).__init__(length=length) + + +class IMAGE(sqltypes.LargeBinary): + __visit_name__ = "IMAGE" + + +class XML(sqltypes.Text): + """MSSQL XML type. + + This is a placeholder type for reflection purposes that does not include + any Python-side datatype support. It also does not currently support + additional arguments, such as "CONTENT", "DOCUMENT", + "xml_schema_collection". + + .. versionadded:: 1.1.11 + + """ + + __visit_name__ = "XML" + + +class BIT(sqltypes.Boolean): + """MSSQL BIT type. + + Both pyodbc and pymssql return values from BIT columns as + Python <class 'bool'> so just subclass Boolean. + + """ + + __visit_name__ = "BIT" + + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = "MONEY" + + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = "SMALLMONEY" + + +class UNIQUEIDENTIFIER(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + + +class SQL_VARIANT(sqltypes.TypeEngine): + __visit_name__ = "SQL_VARIANT" + + +class TryCast(sql.elements.Cast): + """Represent a SQL Server TRY_CAST expression.""" + + __visit_name__ = "try_cast" + + stringify_dialect = "mssql" + inherit_cache = True + + def __init__(self, *arg, **kw): + """Create a TRY_CAST expression. + + :class:`.TryCast` is a subclass of SQLAlchemy's :class:`.Cast` + construct, and works in the same way, except that the SQL expression + rendered is "TRY_CAST" rather than "CAST":: + + from sqlalchemy import select + from sqlalchemy import Numeric + from sqlalchemy.dialects.mssql import try_cast + + stmt = select( + try_cast(product_table.c.unit_price, Numeric(10, 4)) + ) + + The above would render:: + + SELECT TRY_CAST (product_table.unit_price AS NUMERIC(10, 4)) + FROM product_table + + .. versionadded:: 1.3.7 + + """ + super(TryCast, self).__init__(*arg, **kw) + + +try_cast = public_factory(TryCast, ".dialects.mssql.try_cast") + +# old names. +MSDateTime = _MSDateTime +MSDate = _MSDate +MSReal = REAL +MSTinyInteger = TINYINT +MSTime = TIME +MSSmallDateTime = SMALLDATETIME +MSDateTime2 = DATETIME2 +MSDateTimeOffset = DATETIMEOFFSET +MSText = TEXT +MSNText = NTEXT +MSString = VARCHAR +MSNVarchar = NVARCHAR +MSChar = CHAR +MSNChar = NCHAR +MSBinary = BINARY +MSVarBinary = VARBINARY +MSImage = IMAGE +MSBit = BIT +MSMoney = MONEY +MSSmallMoney = SMALLMONEY +MSUniqueIdentifier = UNIQUEIDENTIFIER +MSVariant = SQL_VARIANT + +ischema_names = { + "int": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "tinyint": TINYINT, + "varchar": VARCHAR, + "nvarchar": NVARCHAR, + "char": CHAR, + "nchar": NCHAR, + "text": TEXT, + "ntext": NTEXT, + "decimal": DECIMAL, + "numeric": NUMERIC, + "float": FLOAT, + "datetime": DATETIME, + "datetime2": DATETIME2, + "datetimeoffset": DATETIMEOFFSET, + "date": DATE, + "time": TIME, + "smalldatetime": SMALLDATETIME, + "binary": BINARY, + "varbinary": VARBINARY, + "bit": BIT, + "real": REAL, + "image": IMAGE, + "xml": XML, + "timestamp": TIMESTAMP, + "money": MONEY, + "smallmoney": SMALLMONEY, + "uniqueidentifier": UNIQUEIDENTIFIER, + "sql_variant": SQL_VARIANT, +} + + +class MSTypeCompiler(compiler.GenericTypeCompiler): + def _extend(self, spec, type_, length=None): + """Extend a string-type declaration with standard SQL + COLLATE annotations. + + """ + + if getattr(type_, "collation", None): + collation = "COLLATE %s" % type_.collation + else: + collation = None + + if not length: + length = type_.length + + if length: + spec = spec + "(%s)" % length + + return " ".join([c for c in (spec, collation) if c is not None]) + + def visit_FLOAT(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {"precision": precision} + + def visit_TINYINT(self, type_, **kw): + return "TINYINT" + + def visit_TIME(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is not None: + return "TIME(%s)" % precision + else: + return "TIME" + + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP" + + def visit_ROWVERSION(self, type_, **kw): + return "ROWVERSION" + + def visit_datetime(self, type_, **kw): + if type_.timezone: + return self.visit_DATETIMEOFFSET(type_, **kw) + else: + return self.visit_DATETIME(type_, **kw) + + def visit_DATETIMEOFFSET(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is not None: + return "DATETIMEOFFSET(%s)" % type_.precision + else: + return "DATETIMEOFFSET" + + def visit_DATETIME2(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is not None: + return "DATETIME2(%s)" % precision + else: + return "DATETIME2" + + def visit_SMALLDATETIME(self, type_, **kw): + return "SMALLDATETIME" + + def visit_unicode(self, type_, **kw): + return self.visit_NVARCHAR(type_, **kw) + + def visit_text(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_VARCHAR(type_, **kw) + else: + return self.visit_TEXT(type_, **kw) + + def visit_unicode_text(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_NVARCHAR(type_, **kw) + else: + return self.visit_NTEXT(type_, **kw) + + def visit_NTEXT(self, type_, **kw): + return self._extend("NTEXT", type_) + + def visit_TEXT(self, type_, **kw): + return self._extend("TEXT", type_) + + def visit_VARCHAR(self, type_, **kw): + return self._extend("VARCHAR", type_, length=type_.length or "max") + + def visit_CHAR(self, type_, **kw): + return self._extend("CHAR", type_) + + def visit_NCHAR(self, type_, **kw): + return self._extend("NCHAR", type_) + + def visit_NVARCHAR(self, type_, **kw): + return self._extend("NVARCHAR", type_, length=type_.length or "max") + + def visit_date(self, type_, **kw): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_, **kw) + else: + return self.visit_DATE(type_, **kw) + + def visit__BASETIMEIMPL(self, type_, **kw): + return self.visit_time(type_, **kw) + + def visit_time(self, type_, **kw): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_, **kw) + else: + return self.visit_TIME(type_, **kw) + + def visit_large_binary(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_VARBINARY(type_, **kw) + else: + return self.visit_IMAGE(type_, **kw) + + def visit_IMAGE(self, type_, **kw): + return "IMAGE" + + def visit_XML(self, type_, **kw): + return "XML" + + def visit_VARBINARY(self, type_, **kw): + text = self._extend("VARBINARY", type_, length=type_.length or "max") + if getattr(type_, "filestream", False): + text += " FILESTREAM" + return text + + def visit_boolean(self, type_, **kw): + return self.visit_BIT(type_) + + def visit_BIT(self, type_, **kw): + return "BIT" + + def visit_JSON(self, type_, **kw): + # this is a bit of a break with SQLAlchemy's convention of + # "UPPERCASE name goes to UPPERCASE type name with no modification" + return self._extend("NVARCHAR", type_, length="max") + + def visit_MONEY(self, type_, **kw): + return "MONEY" + + def visit_SMALLMONEY(self, type_, **kw): + return "SMALLMONEY" + + def visit_UNIQUEIDENTIFIER(self, type_, **kw): + return "UNIQUEIDENTIFIER" + + def visit_SQL_VARIANT(self, type_, **kw): + return "SQL_VARIANT" + + +class MSExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + _select_lastrowid = False + _lastrowid = None + _rowcount = None + + def _opt_encode(self, statement): + + if not self.dialect.supports_unicode_statements: + encoded = self.dialect._encoder(statement)[0] + else: + encoded = statement + + if self.compiled and self.compiled.schema_translate_map: + + rst = self.compiled.preparer._render_schema_translates + encoded = rst(encoded, self.compiled.schema_translate_map) + + return encoded + + def pre_exec(self): + """Activate IDENTITY_INSERT if needed.""" + + if self.isinsert: + tbl = self.compiled.compile_state.dml_table + id_column = tbl._autoincrement_column + insert_has_identity = (id_column is not None) and ( + not isinstance(id_column.default, Sequence) + ) + + if insert_has_identity: + compile_state = self.compiled.dml_compile_state + self._enable_identity_insert = ( + id_column.key in self.compiled_parameters[0] + ) or ( + compile_state._dict_parameters + and (id_column.key in compile_state._insert_col_keys) + ) + + else: + self._enable_identity_insert = False + + self._select_lastrowid = ( + not self.compiled.inline + and insert_has_identity + and not self.compiled.returning + and not self._enable_identity_insert + and not self.executemany + ) + + if self._enable_identity_insert: + self.root_connection._cursor_execute( + self.cursor, + self._opt_encode( + "SET IDENTITY_INSERT %s ON" + % self.identifier_preparer.format_table(tbl) + ), + (), + self, + ) + + def post_exec(self): + """Disable IDENTITY_INSERT if enabled.""" + + conn = self.root_connection + + if self.isinsert or self.isupdate or self.isdelete: + self._rowcount = self.cursor.rowcount + + if self._select_lastrowid: + if self.dialect.use_scope_identity: + conn._cursor_execute( + self.cursor, + "SELECT scope_identity() AS lastrowid", + (), + self, + ) + else: + conn._cursor_execute( + self.cursor, "SELECT @@identity AS lastrowid", (), self + ) + # fetchall() ensures the cursor is consumed without closing it + row = self.cursor.fetchall()[0] + self._lastrowid = int(row[0]) + + elif ( + self.isinsert or self.isupdate or self.isdelete + ) and self.compiled.returning: + self.cursor_fetch_strategy = ( + _cursor.FullyBufferedCursorFetchStrategy( + self.cursor, + self.cursor.description, + self.cursor.fetchall(), + ) + ) + + if self._enable_identity_insert: + conn._cursor_execute( + self.cursor, + self._opt_encode( + "SET IDENTITY_INSERT %s OFF" + % self.identifier_preparer.format_table( + self.compiled.compile_state.dml_table + ) + ), + (), + self, + ) + + def get_lastrowid(self): + return self._lastrowid + + @property + def rowcount(self): + if self._rowcount is not None: + return self._rowcount + else: + return self.cursor.rowcount + + def handle_dbapi_exception(self, e): + if self._enable_identity_insert: + try: + self.cursor.execute( + self._opt_encode( + "SET IDENTITY_INSERT %s OFF" + % self.identifier_preparer.format_table( + self.compiled.compile_state.dml_table + ) + ) + ) + except Exception: + pass + + def fire_sequence(self, seq, type_): + return self._execute_scalar( + ( + "SELECT NEXT VALUE FOR %s" + % self.identifier_preparer.format_sequence(seq) + ), + type_, + ) + + def get_insert_default(self, column): + if ( + isinstance(column, sa_schema.Column) + and column is column.table._autoincrement_column + and isinstance(column.default, sa_schema.Sequence) + and column.default.optional + ): + return None + return super(MSExecutionContext, self).get_insert_default(column) + + +class MSSQLCompiler(compiler.SQLCompiler): + returning_precedes_values = True + + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + "doy": "dayofyear", + "dow": "weekday", + "milliseconds": "millisecond", + "microseconds": "microsecond", + }, + ) + + def __init__(self, *args, **kwargs): + self.tablealiases = {} + super(MSSQLCompiler, self).__init__(*args, **kwargs) + + def _with_legacy_schema_aliasing(fn): + def decorate(self, *arg, **kw): + if self.dialect.legacy_schema_aliasing: + return fn(self, *arg, **kw) + else: + super_ = getattr(super(MSSQLCompiler, self), fn.__name__) + return super_(*arg, **kw) + + return decorate + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_current_date_func(self, fn, **kw): + return "GETDATE()" + + def visit_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_char_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_concat_op_binary(self, binary, operator, **kw): + return "%s + %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_true(self, expr, **kw): + return "1" + + def visit_false(self, expr, **kw): + return "0" + + def visit_match_op_binary(self, binary, operator, **kw): + return "CONTAINS (%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def get_select_precolumns(self, select, **kw): + """MS-SQL puts TOP, it's version of LIMIT here""" + + s = super(MSSQLCompiler, self).get_select_precolumns(select, **kw) + + if select._has_row_limiting_clause and self._use_top(select): + # ODBC drivers and possibly others + # don't support bind params in the SELECT clause on SQL Server. + # so have to use literal here. + kw["literal_execute"] = True + s += "TOP %s " % self.process( + self._get_limit_or_fetch(select), **kw + ) + if select._fetch_clause is not None: + if select._fetch_clause_options["percent"]: + s += "PERCENT " + if select._fetch_clause_options["with_ties"]: + s += "WITH TIES " + + return s + + def get_from_hint_text(self, table, text): + return text + + def get_crud_hint_text(self, table, text): + return text + + def _get_limit_or_fetch(self, select): + if select._fetch_clause is None: + return select._limit_clause + else: + return select._fetch_clause + + def _use_top(self, select): + return (select._offset_clause is None) and ( + select._simple_int_clause(select._limit_clause) + or ( + # limit can use TOP with is by itself. fetch only uses TOP + # when it needs to because of PERCENT and/or WITH TIES + select._simple_int_clause(select._fetch_clause) + and ( + select._fetch_clause_options["percent"] + or select._fetch_clause_options["with_ties"] + ) + ) + ) + + def fetch_clause(self, cs, **kwargs): + return "" + + def limit_clause(self, cs, **kwargs): + return "" + + def _check_can_use_fetch_limit(self, select): + # to use ROW_NUMBER(), an ORDER BY is required. + # OFFSET are FETCH are options of the ORDER BY clause + if not select._order_by_clause.clauses: + raise exc.CompileError( + "MSSQL requires an order_by when " + "using an OFFSET or a non-simple " + "LIMIT clause" + ) + + if select._fetch_clause_options is not None and ( + select._fetch_clause_options["percent"] + or select._fetch_clause_options["with_ties"] + ): + raise exc.CompileError( + "MSSQL needs TOP to use PERCENT and/or WITH TIES. " + "Only simple fetch without offset can be used." + ) + + def _row_limit_clause(self, select, **kw): + """MSSQL 2012 supports OFFSET/FETCH operators + Use it instead subquery with row_number + + """ + + if self.dialect._supports_offset_fetch and not self._use_top(select): + self._check_can_use_fetch_limit(select) + + text = "" + + if select._offset_clause is not None: + offset_str = self.process(select._offset_clause, **kw) + else: + offset_str = "0" + text += "\n OFFSET %s ROWS" % offset_str + + limit = self._get_limit_or_fetch(select) + + if limit is not None: + text += "\n FETCH FIRST %s ROWS ONLY" % self.process( + limit, **kw + ) + return text + else: + return "" + + def visit_try_cast(self, element, **kw): + return "TRY_CAST (%s AS %s)" % ( + self.process(element.clause, **kw), + self.process(element.typeclause, **kw), + ) + + def translate_select_structure(self, select_stmt, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``row_number()`` criterion. + MSSQL 2012 and above are excluded + + """ + select = select_stmt + + if ( + select._has_row_limiting_clause + and not self.dialect._supports_offset_fetch + and not self._use_top(select) + and not getattr(select, "_mssql_visit", None) + ): + self._check_can_use_fetch_limit(select) + + _order_by_clauses = [ + sql_util.unwrap_label_reference(elem) + for elem in select._order_by_clause.clauses + ] + + limit_clause = self._get_limit_or_fetch(select) + offset_clause = select._offset_clause + + select = select._generate() + select._mssql_visit = True + select = ( + select.add_columns( + sql.func.ROW_NUMBER() + .over(order_by=_order_by_clauses) + .label("mssql_rn") + ) + .order_by(None) + .alias() + ) + + mssql_rn = sql.column("mssql_rn") + limitselect = sql.select( + *[c for c in select.c if c.key != "mssql_rn"] + ) + if offset_clause is not None: + limitselect = limitselect.where(mssql_rn > offset_clause) + if limit_clause is not None: + limitselect = limitselect.where( + mssql_rn <= (limit_clause + offset_clause) + ) + else: + limitselect = limitselect.where(mssql_rn <= (limit_clause)) + return limitselect + else: + return select + + @_with_legacy_schema_aliasing + def visit_table(self, table, mssql_aliased=False, iscrud=False, **kwargs): + if mssql_aliased is table or iscrud: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + # alias schema-qualified tables + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=table, **kwargs) + else: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + @_with_legacy_schema_aliasing + def visit_alias(self, alias, **kw): + # translate for schema-qualified table aliases + kw["mssql_aliased"] = alias.element + return super(MSSQLCompiler, self).visit_alias(alias, **kw) + + @_with_legacy_schema_aliasing + def visit_column(self, column, add_to_result_map=None, **kw): + if ( + column.table is not None + and (not self.isupdate and not self.isdelete) + or self.is_subquery() + ): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = elements._corresponding_column_or_error(t, column) + if add_to_result_map is not None: + add_to_result_map( + column.name, + column.name, + (column, column.name, column.key), + column.type, + ) + + return super(MSSQLCompiler, self).visit_column(converted, **kw) + + return super(MSSQLCompiler, self).visit_column( + column, add_to_result_map=add_to_result_map, **kw + ) + + def _schema_aliased_table(self, table): + if getattr(table, "schema", None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_extract(self, extract, **kw): + field = self.extract_map.get(extract.field, extract.field) + return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw)) + + def visit_savepoint(self, savepoint_stmt): + return "SAVE TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + def visit_binary(self, binary, **kwargs): + """Move bind parameters to the right-hand side of an operator, where + possible. + + """ + if ( + isinstance(binary.left, expression.BindParameter) + and binary.operator == operator.eq + and not isinstance(binary.right, expression.BindParameter) + ): + return self.process( + expression.BinaryExpression( + binary.right, binary.left, binary.operator + ), + **kwargs + ) + return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) + + def returning_clause(self, stmt, returning_cols): + # SQL server returning clause requires that the columns refer to + # the virtual table names "inserted" or "deleted". Here, we make + # a simple alias of our table with that name, and then adapt the + # columns we have from the list of RETURNING columns to that new name + # so that they render as "inserted.<colname>" / "deleted.<colname>". + + if self.isinsert or self.isupdate: + target = stmt.table.alias("inserted") + else: + target = stmt.table.alias("deleted") + + adapter = sql_util.ClauseAdapter(target) + + # adapter.traverse() takes a column from our target table and returns + # the one that is linked to the "inserted" / "deleted" tables. So in + # order to retrieve these values back from the result (e.g. like + # row[column]), tell the compiler to also add the original unadapted + # column to the result map. Before #4877, these were (unknowingly) + # falling back using string name matching in the result set which + # necessarily used an expensive KeyError in order to match. + + columns = [ + self._label_returning_column( + stmt, + adapter.traverse(c), + {"result_map_targets": (c,)}, + ) + for c in expression._select_iterables(returning_cols) + ] + + return "OUTPUT " + ", ".join(columns) + + def get_cte_preamble(self, recursive): + # SQL Server finds it too inconvenient to accept + # an entirely optional, SQL standard specified, + # "RECURSIVE" word with their "WITH", + # so here we go + return "WITH" + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super(MSSQLCompiler, self).label_select_column( + select, column, asfrom + ) + + def for_update_clause(self, select, **kw): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which + # SQLAlchemy doesn't use + return "" + + def order_by_clause(self, select, **kw): + # MSSQL only allows ORDER BY in subqueries if there is a LIMIT + if ( + self.is_subquery() + and not select._limit + and ( + select._offset is None + or not self.dialect._supports_offset_fetch + ) + ): + # avoid processing the order by clause if we won't end up + # using it, because we don't want all the bind params tacked + # onto the positional list if that is what the dbapi requires + return "" + + order_by = self.process(select._order_by_clause, **kw) + + if order_by: + return " ORDER BY " + order_by + else: + return "" + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the UPDATE..FROM clause specific to MSSQL. + + In MSSQL, if the UPDATE statement involves an alias of the table to + be updated, then the table itself must be added to the FROM list as + well. Otherwise, it is optional. Here, we add it regardless. + + """ + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) + + def delete_table_clause(self, delete_stmt, from_table, extra_froms): + """If we have extra froms make sure we render any alias as hint.""" + ashint = False + if extra_froms: + ashint = True + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, ashint=ashint + ) + + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the DELETE .. FROM clause specific to MSSQL. + + Yes, it has the FROM keyword twice. + + """ + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) + + def visit_empty_set_expr(self, type_): + return "SELECT 1 WHERE 1!=1" + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "NOT EXISTS (SELECT %s INTERSECT SELECT %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "EXISTS (SELECT %s INTERSECT SELECT %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def _render_json_extract_from_binary(self, binary, operator, **kw): + # note we are intentionally calling upon the process() calls in the + # order in which they appear in the SQL String as this is used + # by positional parameter rendering + + if binary.type._type_affinity is sqltypes.JSON: + return "JSON_QUERY(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + # as with other dialects, start with an explicit test for NULL + case_expression = "CASE JSON_VALUE(%s, %s) WHEN NULL THEN NULL" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + if binary.type._type_affinity is sqltypes.Integer: + type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS INTEGER)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + elif binary.type._type_affinity is sqltypes.Numeric: + type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + "FLOAT" + if isinstance(binary.type, sqltypes.Float) + else "NUMERIC(%s, %s)" + % (binary.type.precision, binary.type.scale), + ) + elif binary.type._type_affinity is sqltypes.Boolean: + # the NULL handling is particularly weird with boolean, so + # explicitly return numeric (BIT) constants + type_expression = ( + "WHEN 'true' THEN 1 WHEN 'false' THEN 0 ELSE NULL" + ) + elif binary.type._type_affinity is sqltypes.String: + # TODO: does this comment (from mysql) apply to here, too? + # this fails with a JSON value that's a four byte unicode + # string. SQLite has the same problem at the moment + type_expression = "ELSE JSON_VALUE(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + else: + # other affinity....this is not expected right now + type_expression = "ELSE JSON_QUERY(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + return case_expression + " " + type_expression + " END" + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_sequence(self, seq, **kw): + return "NEXT VALUE FOR %s" % self.preparer.format_sequence(seq) + + +class MSSQLStrictCompiler(MSSQLCompiler): + + """A subclass of MSSQLCompiler which disables the usage of bind + parameters where not allowed natively by MS-SQL. + + A dialect may use this compiler on a platform where native + binds are used. + + """ + + ansi_bind_rules = True + + def visit_in_op_binary(self, binary, operator, **kw): + kw["literal_execute"] = True + return "%s IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_not_in_op_binary(self, binary, operator, **kw): + kw["literal_execute"] = True + return "%s NOT IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def render_literal_value(self, value, type_): + """ + For date and datetime values, convert to a string + format acceptable to MSSQL. That seems to be the + so-called ODBC canonical date format which looks + like this: + + yyyy-mm-dd hh:mi:ss.mmm(24h) + + For other data types, call the base class implementation. + """ + # datetime and date are both subclasses of datetime.date + if issubclass(type(value), datetime.date): + # SQL Server wants single quotes around the date string. + return "'" + str(value) + "'" + else: + return super(MSSQLStrictCompiler, self).render_literal_value( + value, type_ + ) + + +class MSDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + + # type is not accepted in a computed column + if column.computed is not None: + colspec += " " + self.process(column.computed) + else: + colspec += " " + self.dialect.type_compiler.process( + column.type, type_expression=column + ) + + if column.nullable is not None: + if ( + not column.nullable + or column.primary_key + or isinstance(column.default, sa_schema.Sequence) + or column.autoincrement is True + or column.identity + ): + colspec += " NOT NULL" + elif column.computed is None: + # don't specify "NULL" for computed columns + colspec += " NULL" + + if column.table is None: + raise exc.CompileError( + "mssql requires Table-bound columns " + "in order to generate DDL" + ) + + d_opt = column.dialect_options["mssql"] + start = d_opt["identity_start"] + increment = d_opt["identity_increment"] + if start is not None or increment is not None: + if column.identity: + raise exc.CompileError( + "Cannot specify options 'mssql_identity_start' and/or " + "'mssql_identity_increment' while also using the " + "'Identity' construct." + ) + util.warn_deprecated( + "The dialect options 'mssql_identity_start' and " + "'mssql_identity_increment' are deprecated. " + "Use the 'Identity' object instead.", + "1.4", + ) + + if column.identity: + colspec += self.process(column.identity, **kwargs) + elif ( + column is column.table._autoincrement_column + or column.autoincrement is True + ) and ( + not isinstance(column.default, Sequence) or column.default.optional + ): + colspec += self.process(Identity(start=start, increment=increment)) + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_create_index(self, create, include_schema=False): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + + # handle clustering option + clustered = index.dialect_options["mssql"]["clustered"] + if clustered is not None: + if clustered: + text += "CLUSTERED " + else: + text += "NONCLUSTERED " + + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table(index.table), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) + + # handle other included columns + if index.dialect_options["mssql"]["include"]: + inclusions = [ + index.table.c[col] + if isinstance(col, util.string_types) + else col + for col in index.dialect_options["mssql"]["include"] + ] + + text += " INCLUDE (%s)" % ", ".join( + [preparer.quote(c.name) for c in inclusions] + ) + + whereclause = index.dialect_options["mssql"]["where"] + + if whereclause is not None: + whereclause = coercions.expect( + roles.DDLExpressionRole, whereclause + ) + + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, literal_binds=True + ) + text += " WHERE " + where_compiled + + return text + + def visit_drop_index(self, drop): + return "\nDROP INDEX %s ON %s" % ( + self._prepared_index_name(drop.element, include_schema=False), + self.preparer.format_table(drop.element.table), + ) + + def visit_primary_key_constraint(self, constraint): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) + text += "PRIMARY KEY " + + clustered = constraint.dialect_options["mssql"]["clustered"] + if clustered is not None: + if clustered: + text += "CLUSTERED " + else: + text += "NONCLUSTERED " + + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_unique_constraint(self, constraint): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "UNIQUE " + + clustered = constraint.dialect_options["mssql"]["clustered"] + if clustered is not None: + if clustered: + text += "CLUSTERED " + else: + text += "NONCLUSTERED " + + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_computed_column(self, generated): + text = "AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + # explicitly check for True|False since None means server default + if generated.persisted is True: + text += " PERSISTED" + return text + + def visit_create_sequence(self, create, **kw): + prefix = None + if create.element.data_type is not None: + data_type = create.element.data_type + prefix = " AS %s" % self.type_compiler.process(data_type) + return super(MSDDLCompiler, self).visit_create_sequence( + create, prefix=prefix, **kw + ) + + def visit_identity_column(self, identity, **kw): + text = " IDENTITY" + if identity.start is not None or identity.increment is not None: + start = 1 if identity.start is None else identity.start + increment = 1 if identity.increment is None else identity.increment + text += "(%s,%s)" % (start, increment) + return text + + +class MSIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super(MSIdentifierPreparer, self).__init__( + dialect, + initial_quote="[", + final_quote="]", + quote_case_sensitive_collations=False, + ) + + def _escape_identifier(self, value): + return value.replace("]", "]]") + + def _unescape_identifier(self, value): + return value.replace("]]", "]") + + def quote_schema(self, schema, force=None): + """Prepare a quoted table and schema name.""" + + # need to re-implement the deprecation warning entirely + if force is not None: + # not using the util.deprecated_params() decorator in this + # case because of the additional function call overhead on this + # very performance-critical spot. + util.warn_deprecated( + "The IdentifierPreparer.quote_schema.force parameter is " + "deprecated and will be removed in a future release. This " + "flag has no effect on the behavior of the " + "IdentifierPreparer.quote method; please refer to " + "quoted_name().", + version="1.3", + ) + + dbname, owner = _schema_elements(schema) + if dbname: + result = "%s.%s" % (self.quote(dbname), self.quote(owner)) + elif owner: + result = self.quote(owner) + else: + result = "" + return result + + +def _db_plus_owner_listing(fn): + def wrap(dialect, connection, schema=None, **kw): + dbname, owner = _owner_plus_db(dialect, schema) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + dbname, + owner, + schema, + **kw + ) + + return update_wrapper(wrap, fn) + + +def _db_plus_owner(fn): + def wrap(dialect, connection, tablename, schema=None, **kw): + dbname, owner = _owner_plus_db(dialect, schema) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + tablename, + dbname, + owner, + schema, + **kw + ) + + return update_wrapper(wrap, fn) + + +def _switch_db(dbname, connection, fn, *arg, **kw): + if dbname: + current_db = connection.exec_driver_sql("select db_name()").scalar() + if current_db != dbname: + connection.exec_driver_sql( + "use %s" % connection.dialect.identifier_preparer.quote(dbname) + ) + try: + return fn(*arg, **kw) + finally: + if dbname and current_db != dbname: + connection.exec_driver_sql( + "use %s" + % connection.dialect.identifier_preparer.quote(current_db) + ) + + +def _owner_plus_db(dialect, schema): + if not schema: + return None, dialect.default_schema_name + elif "." in schema: + return _schema_elements(schema) + else: + return None, schema + + +_memoized_schema = util.LRUCache() + + +def _schema_elements(schema): + if isinstance(schema, quoted_name) and schema.quote: + return None, schema + + if schema in _memoized_schema: + return _memoized_schema[schema] + + # tests for this function are in: + # test/dialect/mssql/test_reflection.py -> + # OwnerPlusDBTest.test_owner_database_pairs + # test/dialect/mssql/test_compiler.py -> test_force_schema_* + # test/dialect/mssql/test_compiler.py -> test_schema_many_tokens_* + # + + if schema.startswith("__[SCHEMA_"): + return None, schema + + push = [] + symbol = "" + bracket = False + has_brackets = False + for token in re.split(r"(\[|\]|\.)", schema): + if not token: + continue + if token == "[": + bracket = True + has_brackets = True + elif token == "]": + bracket = False + elif not bracket and token == ".": + if has_brackets: + push.append("[%s]" % symbol) + else: + push.append(symbol) + symbol = "" + has_brackets = False + else: + symbol += token + if symbol: + push.append(symbol) + if len(push) > 1: + dbname, owner = ".".join(push[0:-1]), push[-1] + + # test for internal brackets + if re.match(r".*\].*\[.*", dbname[1:-1]): + dbname = quoted_name(dbname, quote=False) + else: + dbname = dbname.lstrip("[").rstrip("]") + + elif len(push): + dbname, owner = None, push[0] + else: + dbname, owner = None, None + + _memoized_schema[schema] = dbname, owner + return dbname, owner + + +class MSDialect(default.DefaultDialect): + # will assume it's at least mssql2005 + name = "mssql" + supports_statement_cache = True + supports_default_values = True + supports_empty_insert = False + execution_ctx_cls = MSExecutionContext + use_scope_identity = True + max_identifier_length = 128 + schema_name = "dbo" + + implicit_returning = True + full_returning = True + + colspecs = { + sqltypes.DateTime: _MSDateTime, + sqltypes.Date: _MSDate, + sqltypes.JSON: JSON, + sqltypes.JSON.JSONIndexType: JSONIndexType, + sqltypes.JSON.JSONPathType: JSONPathType, + sqltypes.Time: _BASETIMEIMPL, + sqltypes.Unicode: _MSUnicode, + sqltypes.UnicodeText: _MSUnicodeText, + DATETIMEOFFSET: DATETIMEOFFSET, + DATETIME2: DATETIME2, + SMALLDATETIME: SMALLDATETIME, + DATETIME: DATETIME, + } + + engine_config_types = default.DefaultDialect.engine_config_types.union( + {"legacy_schema_aliasing": util.asbool} + ) + + ischema_names = ischema_names + + supports_sequences = True + sequences_optional = True + # T-SQL's actual default is -9223372036854775808 + default_sequence_base = 1 + + supports_native_boolean = False + non_native_boolean_check_constraint = False + supports_unicode_binds = True + postfetch_lastrowid = True + _supports_offset_fetch = False + _supports_nvarchar_max = False + + legacy_schema_aliasing = False + + server_version_info = () + + statement_compiler = MSSQLCompiler + ddl_compiler = MSDDLCompiler + type_compiler = MSTypeCompiler + preparer = MSIdentifierPreparer + + construct_arguments = [ + (sa_schema.PrimaryKeyConstraint, {"clustered": None}), + (sa_schema.UniqueConstraint, {"clustered": None}), + (sa_schema.Index, {"clustered": None, "include": None, "where": None}), + ( + sa_schema.Column, + {"identity_start": None, "identity_increment": None}, + ), + ] + + def __init__( + self, + query_timeout=None, + use_scope_identity=True, + schema_name="dbo", + isolation_level=None, + deprecate_large_types=None, + json_serializer=None, + json_deserializer=None, + legacy_schema_aliasing=None, + ignore_no_transaction_on_rollback=False, + **opts + ): + self.query_timeout = int(query_timeout or 0) + self.schema_name = schema_name + + self.use_scope_identity = use_scope_identity + self.deprecate_large_types = deprecate_large_types + self.ignore_no_transaction_on_rollback = ( + ignore_no_transaction_on_rollback + ) + + if legacy_schema_aliasing is not None: + util.warn_deprecated( + "The legacy_schema_aliasing parameter is " + "deprecated and will be removed in a future release.", + "1.4", + ) + self.legacy_schema_aliasing = legacy_schema_aliasing + + super(MSDialect, self).__init__(**opts) + + self.isolation_level = isolation_level + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer + + def do_savepoint(self, connection, name): + # give the DBAPI a push + connection.exec_driver_sql("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") + super(MSDialect, self).do_savepoint(connection, name) + + def do_release_savepoint(self, connection, name): + # SQL Server does not support RELEASE SAVEPOINT + pass + + def do_rollback(self, dbapi_connection): + try: + super(MSDialect, self).do_rollback(dbapi_connection) + except self.dbapi.ProgrammingError as e: + if self.ignore_no_transaction_on_rollback and re.match( + r".*\b111214\b", str(e) + ): + util.warn( + "ProgrammingError 111214 " + "'No corresponding transaction found.' " + "has been suppressed via " + "ignore_no_transaction_on_rollback=True" + ) + else: + raise + + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "SNAPSHOT", + ] + ) + + def set_isolation_level(self, connection, level): + level = level.replace("_", " ") + if level not in self._isolation_lookup: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) + ) + cursor = connection.cursor() + cursor.execute("SET TRANSACTION ISOLATION LEVEL %s" % level) + cursor.close() + if level == "SNAPSHOT": + connection.commit() + + def get_isolation_level(self, dbapi_connection): + cursor = dbapi_connection.cursor() + try: + cursor.execute( + "SELECT name FROM sys.system_views WHERE name IN " + "('dm_exec_sessions', 'dm_pdw_nodes_exec_sessions')" + ) + row = cursor.fetchone() + if not row: + raise NotImplementedError( + "Can't fetch isolation level on this particular " + "SQL Server version." + ) + + view_name = "sys.{}".format(row[0]) + cursor.execute( + """ + SELECT CASE transaction_isolation_level + WHEN 0 THEN NULL + WHEN 1 THEN 'READ UNCOMMITTED' + WHEN 2 THEN 'READ COMMITTED' + WHEN 3 THEN 'REPEATABLE READ' + WHEN 4 THEN 'SERIALIZABLE' + WHEN 5 THEN 'SNAPSHOT' END AS TRANSACTION_ISOLATION_LEVEL + FROM {} + where session_id = @@SPID + """.format( + view_name + ) + ) + row = cursor.fetchone() + assert row is not None + val = row[0] + finally: + cursor.close() + return val.upper() + + def initialize(self, connection): + super(MSDialect, self).initialize(connection) + self._setup_version_attributes() + self._setup_supports_nvarchar_max(connection) + + def on_connect(self): + if self.isolation_level is not None: + + def connect(conn): + self.set_isolation_level(conn, self.isolation_level) + + return connect + else: + return None + + def _setup_version_attributes(self): + if self.server_version_info[0] not in list(range(8, 17)): + util.warn( + "Unrecognized server version info '%s'. Some SQL Server " + "features may not function properly." + % ".".join(str(x) for x in self.server_version_info) + ) + + if self.server_version_info >= MS_2008_VERSION: + self.supports_multivalues_insert = True + if self.deprecate_large_types is None: + self.deprecate_large_types = ( + self.server_version_info >= MS_2012_VERSION + ) + + self._supports_offset_fetch = ( + self.server_version_info and self.server_version_info[0] >= 11 + ) + + def _setup_supports_nvarchar_max(self, connection): + try: + connection.scalar( + sql.text("SELECT CAST('test max support' AS NVARCHAR(max))") + ) + except exc.DBAPIError: + self._supports_nvarchar_max = False + else: + self._supports_nvarchar_max = True + + def _get_default_schema_name(self, connection): + query = sql.text("SELECT schema_name()") + default_schema_name = connection.scalar(query) + if default_schema_name is not None: + # guard against the case where the default_schema_name is being + # fed back into a table reflection function. + return quoted_name(default_schema_name, quote=True) + else: + return self.schema_name + + @_db_plus_owner + def has_table(self, connection, tablename, dbname, owner, schema): + self._ensure_has_table_connection(connection) + if tablename.startswith("#"): # temporary table + tables = ischema.mssql_temp_table_columns + + s = sql.select(tables.c.table_name).where( + tables.c.table_name.like( + self._temp_table_name_like_pattern(tablename) + ) + ) + + # #7168: fetch all (not just first match) in case some other #temp + # table with the same name happens to appear first + table_names = connection.execute(s).scalars().fetchall() + # #6910: verify it's not a temp table from another session + for table_name in table_names: + if bool( + connection.scalar( + text("SELECT object_id(:table_name)"), + {"table_name": "tempdb.dbo.[{}]".format(table_name)}, + ) + ): + return True + else: + return False + else: + tables = ischema.tables + + s = sql.select(tables.c.table_name).where( + sql.and_( + tables.c.table_type == "BASE TABLE", + tables.c.table_name == tablename, + ) + ) + + if owner: + s = s.where(tables.c.table_schema == owner) + + c = connection.execute(s) + + return c.first() is not None + + @_db_plus_owner + def has_sequence(self, connection, sequencename, dbname, owner, schema): + sequences = ischema.sequences + + s = sql.select(sequences.c.sequence_name).where( + sequences.c.sequence_name == sequencename + ) + + if owner: + s = s.where(sequences.c.sequence_schema == owner) + + c = connection.execute(s) + + return c.first() is not None + + @reflection.cache + @_db_plus_owner_listing + def get_sequence_names(self, connection, dbname, owner, schema, **kw): + sequences = ischema.sequences + + s = sql.select(sequences.c.sequence_name) + if owner: + s = s.where(sequences.c.sequence_schema == owner) + + c = connection.execute(s) + + return [row[0] for row in c] + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = sql.select(ischema.schemata.c.schema_name).order_by( + ischema.schemata.c.schema_name + ) + schema_names = [r[0] for r in connection.execute(s)] + return schema_names + + @reflection.cache + @_db_plus_owner_listing + def get_table_names(self, connection, dbname, owner, schema, **kw): + tables = ischema.tables + s = ( + sql.select(tables.c.table_name) + .where( + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == "BASE TABLE", + ) + ) + .order_by(tables.c.table_name) + ) + table_names = [r[0] for r in connection.execute(s)] + return table_names + + @reflection.cache + @_db_plus_owner_listing + def get_view_names(self, connection, dbname, owner, schema, **kw): + tables = ischema.tables + s = ( + sql.select(tables.c.table_name) + .where( + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == "VIEW", + ) + ) + .order_by(tables.c.table_name) + ) + view_names = [r[0] for r in connection.execute(s)] + return view_names + + @reflection.cache + @_db_plus_owner + def get_indexes(self, connection, tablename, dbname, owner, schema, **kw): + filter_definition = ( + "ind.filter_definition" + if self.server_version_info >= MS_2008_VERSION + else "NULL as filter_definition" + ) + rp = connection.execution_options(future_result=True).execute( + sql.text( + "select ind.index_id, ind.is_unique, ind.name, " + "%s " + "from sys.indexes as ind join sys.tables as tab on " + "ind.object_id=tab.object_id " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + "where tab.name = :tabname " + "and sch.name=:schname " + "and ind.is_primary_key=0 and ind.type != 0" + % filter_definition + ) + .bindparams( + sql.bindparam("tabname", tablename, ischema.CoerceUnicode()), + sql.bindparam("schname", owner, ischema.CoerceUnicode()), + ) + .columns(name=sqltypes.Unicode()) + ) + indexes = {} + for row in rp.mappings(): + indexes[row["index_id"]] = { + "name": row["name"], + "unique": row["is_unique"] == 1, + "column_names": [], + "include_columns": [], + } + + if row["filter_definition"] is not None: + indexes[row["index_id"]].setdefault("dialect_options", {})[ + "mssql_where" + ] = row["filter_definition"] + + rp = connection.execution_options(future_result=True).execute( + sql.text( + "select ind_col.index_id, ind_col.object_id, col.name, " + "ind_col.is_included_column " + "from sys.columns as col " + "join sys.tables as tab on tab.object_id=col.object_id " + "join sys.index_columns as ind_col on " + "(ind_col.column_id=col.column_id and " + "ind_col.object_id=tab.object_id) " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + "where tab.name=:tabname " + "and sch.name=:schname" + ) + .bindparams( + sql.bindparam("tabname", tablename, ischema.CoerceUnicode()), + sql.bindparam("schname", owner, ischema.CoerceUnicode()), + ) + .columns(name=sqltypes.Unicode()) + ) + for row in rp.mappings(): + if row["index_id"] in indexes: + if row["is_included_column"]: + indexes[row["index_id"]]["include_columns"].append( + row["name"] + ) + else: + indexes[row["index_id"]]["column_names"].append( + row["name"] + ) + for index_info in indexes.values(): + # NOTE: "root level" include_columns is legacy, now part of + # dialect_options (issue #7382) + index_info.setdefault("dialect_options", {})[ + "mssql_include" + ] = index_info["include_columns"] + + return list(indexes.values()) + + @reflection.cache + @_db_plus_owner + def get_view_definition( + self, connection, viewname, dbname, owner, schema, **kw + ): + rp = connection.execute( + sql.text( + "select definition from sys.sql_modules as mod, " + "sys.views as views, " + "sys.schemas as sch" + " where " + "mod.object_id=views.object_id and " + "views.schema_id=sch.schema_id and " + "views.name=:viewname and sch.name=:schname" + ).bindparams( + sql.bindparam("viewname", viewname, ischema.CoerceUnicode()), + sql.bindparam("schname", owner, ischema.CoerceUnicode()), + ) + ) + + if rp: + view_def = rp.scalar() + return view_def + + def _temp_table_name_like_pattern(self, tablename): + # LIKE uses '%' to match zero or more characters and '_' to match any + # single character. We want to match literal underscores, so T-SQL + # requires that we enclose them in square brackets. + return tablename + ( + ("[_][_][_]%") if not tablename.startswith("##") else "" + ) + + def _get_internal_temp_table_name(self, connection, tablename): + # it's likely that schema is always "dbo", but since we can + # get it here, let's get it. + # see https://stackoverflow.com/questions/8311959/ + # specifying-schema-for-temporary-tables + + try: + return connection.execute( + sql.text( + "select table_schema, table_name " + "from tempdb.information_schema.tables " + "where table_name like :p1" + ), + {"p1": self._temp_table_name_like_pattern(tablename)}, + ).one() + except exc.MultipleResultsFound as me: + util.raise_( + exc.UnreflectableTableError( + "Found more than one temporary table named '%s' in tempdb " + "at this time. Cannot reliably resolve that name to its " + "internal table name." % tablename + ), + replace_context=me, + ) + except exc.NoResultFound as ne: + util.raise_( + exc.NoSuchTableError( + "Unable to find a temporary table named '%s' in tempdb." + % tablename + ), + replace_context=ne, + ) + + @reflection.cache + @_db_plus_owner + def get_columns(self, connection, tablename, dbname, owner, schema, **kw): + is_temp_table = tablename.startswith("#") + if is_temp_table: + owner, tablename = self._get_internal_temp_table_name( + connection, tablename + ) + + columns = ischema.mssql_temp_table_columns + else: + columns = ischema.columns + + computed_cols = ischema.computed_columns + identity_cols = ischema.identity_columns + if owner: + whereclause = sql.and_( + columns.c.table_name == tablename, + columns.c.table_schema == owner, + ) + full_name = columns.c.table_schema + "." + columns.c.table_name + else: + whereclause = columns.c.table_name == tablename + full_name = columns.c.table_name + + join = columns.join( + computed_cols, + onclause=sql.and_( + computed_cols.c.object_id == func.object_id(full_name), + computed_cols.c.name + == columns.c.column_name.collate("DATABASE_DEFAULT"), + ), + isouter=True, + ).join( + identity_cols, + onclause=sql.and_( + identity_cols.c.object_id == func.object_id(full_name), + identity_cols.c.name + == columns.c.column_name.collate("DATABASE_DEFAULT"), + ), + isouter=True, + ) + + if self._supports_nvarchar_max: + computed_definition = computed_cols.c.definition + else: + # tds_version 4.2 does not support NVARCHAR(MAX) + computed_definition = sql.cast( + computed_cols.c.definition, NVARCHAR(4000) + ) + + s = ( + sql.select( + columns, + computed_definition, + computed_cols.c.is_persisted, + identity_cols.c.is_identity, + identity_cols.c.seed_value, + identity_cols.c.increment_value, + ) + .where(whereclause) + .select_from(join) + .order_by(columns.c.ordinal_position) + ) + + c = connection.execution_options(future_result=True).execute(s) + + cols = [] + for row in c.mappings(): + name = row[columns.c.column_name] + type_ = row[columns.c.data_type] + nullable = row[columns.c.is_nullable] == "YES" + charlen = row[columns.c.character_maximum_length] + numericprec = row[columns.c.numeric_precision] + numericscale = row[columns.c.numeric_scale] + default = row[columns.c.column_default] + collation = row[columns.c.collation_name] + definition = row[computed_definition] + is_persisted = row[computed_cols.c.is_persisted] + is_identity = row[identity_cols.c.is_identity] + identity_start = row[identity_cols.c.seed_value] + identity_increment = row[identity_cols.c.increment_value] + + coltype = self.ischema_names.get(type_, None) + + kwargs = {} + if coltype in ( + MSString, + MSChar, + MSNVarchar, + MSNChar, + MSText, + MSNText, + MSBinary, + MSVarBinary, + sqltypes.LargeBinary, + ): + if charlen == -1: + charlen = None + kwargs["length"] = charlen + if collation: + kwargs["collation"] = collation + + if coltype is None: + util.warn( + "Did not recognize type '%s' of column '%s'" + % (type_, name) + ) + coltype = sqltypes.NULLTYPE + else: + if issubclass(coltype, sqltypes.Numeric): + kwargs["precision"] = numericprec + + if not issubclass(coltype, sqltypes.Float): + kwargs["scale"] = numericscale + + coltype = coltype(**kwargs) + cdict = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": is_identity is not None, + } + + if definition is not None and is_persisted is not None: + cdict["computed"] = { + "sqltext": definition, + "persisted": is_persisted, + } + + if is_identity is not None: + # identity_start and identity_increment are Decimal or None + if identity_start is None or identity_increment is None: + cdict["identity"] = {} + else: + if isinstance(coltype, sqltypes.BigInteger): + start = compat.long_type(identity_start) + increment = compat.long_type(identity_increment) + elif isinstance(coltype, sqltypes.Integer): + start = int(identity_start) + increment = int(identity_increment) + else: + start = identity_start + increment = identity_increment + + cdict["identity"] = { + "start": start, + "increment": increment, + } + + cols.append(cdict) + + return cols + + @reflection.cache + @_db_plus_owner + def get_pk_constraint( + self, connection, tablename, dbname, owner, schema, **kw + ): + pkeys = [] + TC = ischema.constraints + C = ischema.key_constraints.alias("C") + + # Primary key constraints + s = ( + sql.select( + C.c.column_name, TC.c.constraint_type, C.c.constraint_name + ) + .where( + sql.and_( + TC.c.constraint_name == C.c.constraint_name, + TC.c.table_schema == C.c.table_schema, + C.c.table_name == tablename, + C.c.table_schema == owner, + ), + ) + .order_by(TC.c.constraint_name, C.c.ordinal_position) + ) + c = connection.execution_options(future_result=True).execute(s) + constraint_name = None + for row in c.mappings(): + if "PRIMARY" in row[TC.c.constraint_type.name]: + pkeys.append(row["COLUMN_NAME"]) + if constraint_name is None: + constraint_name = row[C.c.constraint_name.name] + return {"constrained_columns": pkeys, "name": constraint_name} + + @reflection.cache + @_db_plus_owner + def get_foreign_keys( + self, connection, tablename, dbname, owner, schema, **kw + ): + # Foreign key constraints + s = ( + text( + """\ +WITH fk_info AS ( + SELECT + ischema_ref_con.constraint_schema, + ischema_ref_con.constraint_name, + ischema_key_col.ordinal_position, + ischema_key_col.table_schema, + ischema_key_col.table_name, + ischema_ref_con.unique_constraint_schema, + ischema_ref_con.unique_constraint_name, + ischema_ref_con.match_option, + ischema_ref_con.update_rule, + ischema_ref_con.delete_rule, + ischema_key_col.column_name AS constrained_column + FROM + INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS ischema_ref_con + INNER JOIN + INFORMATION_SCHEMA.KEY_COLUMN_USAGE ischema_key_col ON + ischema_key_col.table_schema = ischema_ref_con.constraint_schema + AND ischema_key_col.constraint_name = + ischema_ref_con.constraint_name + WHERE ischema_key_col.table_name = :tablename + AND ischema_key_col.table_schema = :owner +), +constraint_info AS ( + SELECT + ischema_key_col.constraint_schema, + ischema_key_col.constraint_name, + ischema_key_col.ordinal_position, + ischema_key_col.table_schema, + ischema_key_col.table_name, + ischema_key_col.column_name + FROM + INFORMATION_SCHEMA.KEY_COLUMN_USAGE ischema_key_col +), +index_info AS ( + SELECT + sys.schemas.name AS index_schema, + sys.indexes.name AS index_name, + sys.index_columns.key_ordinal AS ordinal_position, + sys.schemas.name AS table_schema, + sys.objects.name AS table_name, + sys.columns.name AS column_name + FROM + sys.indexes + INNER JOIN + sys.objects ON + sys.objects.object_id = sys.indexes.object_id + INNER JOIN + sys.schemas ON + sys.schemas.schema_id = sys.objects.schema_id + INNER JOIN + sys.index_columns ON + sys.index_columns.object_id = sys.objects.object_id + AND sys.index_columns.index_id = sys.indexes.index_id + INNER JOIN + sys.columns ON + sys.columns.object_id = sys.indexes.object_id + AND sys.columns.column_id = sys.index_columns.column_id +) + SELECT + fk_info.constraint_schema, + fk_info.constraint_name, + fk_info.ordinal_position, + fk_info.constrained_column, + constraint_info.table_schema AS referred_table_schema, + constraint_info.table_name AS referred_table_name, + constraint_info.column_name AS referred_column, + fk_info.match_option, + fk_info.update_rule, + fk_info.delete_rule + FROM + fk_info INNER JOIN constraint_info ON + constraint_info.constraint_schema = + fk_info.unique_constraint_schema + AND constraint_info.constraint_name = + fk_info.unique_constraint_name + AND constraint_info.ordinal_position = fk_info.ordinal_position + UNION + SELECT + fk_info.constraint_schema, + fk_info.constraint_name, + fk_info.ordinal_position, + fk_info.constrained_column, + index_info.table_schema AS referred_table_schema, + index_info.table_name AS referred_table_name, + index_info.column_name AS referred_column, + fk_info.match_option, + fk_info.update_rule, + fk_info.delete_rule + FROM + fk_info INNER JOIN index_info ON + index_info.index_schema = fk_info.unique_constraint_schema + AND index_info.index_name = fk_info.unique_constraint_name + AND index_info.ordinal_position = fk_info.ordinal_position + + ORDER BY fk_info.constraint_schema, fk_info.constraint_name, + fk_info.ordinal_position +""" + ) + .bindparams( + sql.bindparam("tablename", tablename, ischema.CoerceUnicode()), + sql.bindparam("owner", owner, ischema.CoerceUnicode()), + ) + .columns( + constraint_schema=sqltypes.Unicode(), + constraint_name=sqltypes.Unicode(), + table_schema=sqltypes.Unicode(), + table_name=sqltypes.Unicode(), + constrained_column=sqltypes.Unicode(), + referred_table_schema=sqltypes.Unicode(), + referred_table_name=sqltypes.Unicode(), + referred_column=sqltypes.Unicode(), + ) + ) + + # group rows by constraint ID, to handle multi-column FKs + fkeys = [] + + def fkey_rec(): + return { + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + "options": {}, + } + + fkeys = util.defaultdict(fkey_rec) + + for r in connection.execute(s).fetchall(): + ( + _, # constraint schema + rfknm, + _, # ordinal position + scol, + rschema, + rtbl, + rcol, + # TODO: we support match=<keyword> for foreign keys so + # we can support this also, PG has match=FULL for example + # but this seems to not be a valid value for SQL Server + _, # match rule + fkuprule, + fkdelrule, + ) = r + + rec = fkeys[rfknm] + rec["name"] = rfknm + + if fkuprule != "NO ACTION": + rec["options"]["onupdate"] = fkuprule + + if fkdelrule != "NO ACTION": + rec["options"]["ondelete"] = fkdelrule + + if not rec["referred_table"]: + rec["referred_table"] = rtbl + if schema is not None or owner != rschema: + if dbname: + rschema = dbname + "." + rschema + rec["referred_schema"] = rschema + + local_cols, remote_cols = ( + rec["constrained_columns"], + rec["referred_columns"], + ) + + local_cols.append(scol) + remote_cols.append(rcol) + + return list(fkeys.values()) diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py new file mode 100644 index 0000000..df91493 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -0,0 +1,232 @@ +# mssql/information_schema.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from ... import cast +from ... import Column +from ... import MetaData +from ... import Table +from ... import util +from ...ext.compiler import compiles +from ...sql import expression +from ...types import Boolean +from ...types import Integer +from ...types import Numeric +from ...types import String +from ...types import TypeDecorator +from ...types import Unicode + + +ischema = MetaData() + + +class CoerceUnicode(TypeDecorator): + impl = Unicode + cache_ok = True + + def process_bind_param(self, value, dialect): + if util.py2k and isinstance(value, util.binary_type): + value = value.decode(dialect.encoding) + return value + + def bind_expression(self, bindvalue): + return _cast_on_2005(bindvalue) + + +class _cast_on_2005(expression.ColumnElement): + def __init__(self, bindvalue): + self.bindvalue = bindvalue + + +@compiles(_cast_on_2005) +def _compile(element, compiler, **kw): + from . import base + + if ( + compiler.dialect.server_version_info is None + or compiler.dialect.server_version_info < base.MS_2005_VERSION + ): + return compiler.process(element.bindvalue, **kw) + else: + return compiler.process(cast(element.bindvalue, Unicode), **kw) + + +schemata = Table( + "SCHEMATA", + ischema, + Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), + Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), + Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), + schema="INFORMATION_SCHEMA", +) + +tables = Table( + "TABLES", + ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("TABLE_TYPE", CoerceUnicode, key="table_type"), + schema="INFORMATION_SCHEMA", +) + +columns = Table( + "COLUMNS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="INFORMATION_SCHEMA", +) + +mssql_temp_table_columns = Table( + "COLUMNS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="tempdb.INFORMATION_SCHEMA", +) + +constraints = Table( + "TABLE_CONSTRAINTS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_TYPE", CoerceUnicode, key="constraint_type"), + schema="INFORMATION_SCHEMA", +) + +column_constraints = Table( + "CONSTRAINT_COLUMN_USAGE", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + schema="INFORMATION_SCHEMA", +) + +key_constraints = Table( + "KEY_COLUMN_USAGE", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + schema="INFORMATION_SCHEMA", +) + +ref_constraints = Table( + "REFERENTIAL_CONSTRAINTS", + ischema, + Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + # TODO: is CATLOG misspelled ? + Column( + "UNIQUE_CONSTRAINT_CATLOG", + CoerceUnicode, + key="unique_constraint_catalog", + ), + Column( + "UNIQUE_CONSTRAINT_SCHEMA", + CoerceUnicode, + key="unique_constraint_schema", + ), + Column( + "UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name" + ), + Column("MATCH_OPTION", String, key="match_option"), + Column("UPDATE_RULE", String, key="update_rule"), + Column("DELETE_RULE", String, key="delete_rule"), + schema="INFORMATION_SCHEMA", +) + +views = Table( + "VIEWS", + ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), + Column("CHECK_OPTION", String, key="check_option"), + Column("IS_UPDATABLE", String, key="is_updatable"), + schema="INFORMATION_SCHEMA", +) + +computed_columns = Table( + "computed_columns", + ischema, + Column("object_id", Integer), + Column("name", CoerceUnicode), + Column("is_computed", Boolean), + Column("is_persisted", Boolean), + Column("definition", CoerceUnicode), + schema="sys", +) + +sequences = Table( + "SEQUENCES", + ischema, + Column("SEQUENCE_CATALOG", CoerceUnicode, key="sequence_catalog"), + Column("SEQUENCE_SCHEMA", CoerceUnicode, key="sequence_schema"), + Column("SEQUENCE_NAME", CoerceUnicode, key="sequence_name"), + schema="INFORMATION_SCHEMA", +) + + +class IdentitySqlVariant(TypeDecorator): + r"""This type casts sql_variant columns in the identity_columns view + to numeric. This is required because: + + * pyodbc does not support sql_variant + * pymssql under python 2 return the byte representation of the number, + int 1 is returned as "\x01\x00\x00\x00". On python 3 it returns the + correct value as string. + """ + impl = Unicode + cache_ok = True + + def column_expression(self, colexpr): + return cast(colexpr, Numeric) + + +identity_columns = Table( + "identity_columns", + ischema, + Column("object_id", Integer), + Column("name", CoerceUnicode), + Column("is_identity", Boolean), + Column("seed_value", IdentitySqlVariant), + Column("increment_value", IdentitySqlVariant), + Column("last_value", IdentitySqlVariant), + Column("is_not_for_replication", Boolean), + schema="sys", +) diff --git a/lib/sqlalchemy/dialects/mssql/json.py b/lib/sqlalchemy/dialects/mssql/json.py new file mode 100644 index 0000000..d515731 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/json.py @@ -0,0 +1,125 @@ +from ... import types as sqltypes + +# technically, all the dialect-specific datatypes that don't have any special +# behaviors would be private with names like _MSJson. However, we haven't been +# doing this for mysql.JSON or sqlite.JSON which both have JSON / JSONIndexType +# / JSONPathType in their json.py files, so keep consistent with that +# sub-convention for now. A future change can update them all to be +# package-private at once. + + +class JSON(sqltypes.JSON): + """MSSQL JSON type. + + MSSQL supports JSON-formatted data as of SQL Server 2016. + + The :class:`_mssql.JSON` datatype at the DDL level will represent the + datatype as ``NVARCHAR(max)``, but provides for JSON-level comparison + functions as well as Python coercion behavior. + + :class:`_mssql.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a SQL Server backend. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. + + The :class:`_mssql.JSON` type supports persistence of JSON values + as well as the core index operations provided by :class:`_types.JSON` + datatype, by adapting the operations to render the ``JSON_VALUE`` + or ``JSON_QUERY`` functions at the database level. + + The SQL Server :class:`_mssql.JSON` type necessarily makes use of the + ``JSON_QUERY`` and ``JSON_VALUE`` functions when querying for elements + of a JSON object. These two functions have a major restriction in that + they are **mutually exclusive** based on the type of object to be returned. + The ``JSON_QUERY`` function **only** returns a JSON dictionary or list, + but not an individual string, numeric, or boolean element; the + ``JSON_VALUE`` function **only** returns an individual string, numeric, + or boolean element. **both functions either return NULL or raise + an error if they are not used against the correct expected value**. + + To handle this awkward requirement, indexed access rules are as follows: + + 1. When extracting a sub element from a JSON that is itself a JSON + dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor + should be used:: + + stmt = select( + data_table.c.data["some key"].as_json() + ).where( + data_table.c.data["some key"].as_json() == {"sub": "structure"} + ) + + 2. When extracting a sub element from a JSON that is a plain boolean, + string, integer, or float, use the appropriate method among + :meth:`_types.JSON.Comparator.as_boolean`, + :meth:`_types.JSON.Comparator.as_string`, + :meth:`_types.JSON.Comparator.as_integer`, + :meth:`_types.JSON.Comparator.as_float`:: + + stmt = select( + data_table.c.data["some key"].as_string() + ).where( + data_table.c.data["some key"].as_string() == "some string" + ) + + .. versionadded:: 1.4 + + + """ + + # note there was a result processor here that was looking for "number", + # but none of the tests seem to exercise it. + + +# Note: these objects currently match exactly those of MySQL, however since +# these are not generalizable to all JSON implementations, remain separately +# implemented for each dialect. +class _FormatTypeMixin(object): + def _format_value(self, value): + raise NotImplementedError() + + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + def _format_value(self, value): + if isinstance(value, int): + value = "$[%s]" % value + else: + value = '$."%s"' % value + return value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return "$%s" % ( + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) + ) diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py new file mode 100644 index 0000000..95c32d4 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py @@ -0,0 +1,150 @@ +# mssql/mxodbc.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: mssql+mxodbc + :name: mxODBC + :dbapi: mxodbc + :connectstring: mssql+mxodbc://<username>:<password>@<dsnname> + :url: https://www.egenix.com/ + +.. deprecated:: 1.4 The mxODBC DBAPI is deprecated and will be removed + in a future version. Please use one of the supported DBAPIs to + connect to mssql. + +Execution Modes +--------------- + +mxODBC features two styles of statement execution, using the +``cursor.execute()`` and ``cursor.executedirect()`` methods (the second being +an extension to the DBAPI specification). The former makes use of a particular +API call specific to the SQL Server Native Client ODBC driver known +SQLDescribeParam, while the latter does not. + +mxODBC apparently only makes repeated use of a single prepared statement +when SQLDescribeParam is used. The advantage to prepared statement reuse is +one of performance. The disadvantage is that SQLDescribeParam has a limited +set of scenarios in which bind parameters are understood, including that they +cannot be placed within the argument lists of function calls, anywhere outside +the FROM, or even within subqueries within the FROM clause - making the usage +of bind parameters within SELECT statements impossible for all but the most +simplistic statements. + +For this reason, the mxODBC dialect uses the "native" mode by default only for +INSERT, UPDATE, and DELETE statements, and uses the escaped string mode for +all other statements. + +This behavior can be controlled via +:meth:`~sqlalchemy.sql.expression.Executable.execution_options` using the +``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a +value of ``True`` will unconditionally use native bind parameters and a value +of ``False`` will unconditionally use string-escaped parameters. + +""" + + +from .base import _MSDate +from .base import _MSDateTime +from .base import _MSTime +from .base import MSDialect +from .base import VARBINARY +from .pyodbc import _MSNumeric_pyodbc +from .pyodbc import MSExecutionContext_pyodbc +from ... import types as sqltypes +from ...connectors.mxodbc import MxODBCConnector + + +class _MSNumeric_mxodbc(_MSNumeric_pyodbc): + """Include pyodbc's numeric processor.""" + + +class _MSDate_mxodbc(_MSDate): + def bind_processor(self, dialect): + def process(value): + if value is not None: + return "%s-%s-%s" % (value.year, value.month, value.day) + else: + return None + + return process + + +class _MSTime_mxodbc(_MSTime): + def bind_processor(self, dialect): + def process(value): + if value is not None: + return "%s:%s:%s" % (value.hour, value.minute, value.second) + else: + return None + + return process + + +class _VARBINARY_mxodbc(VARBINARY): + + """ + mxODBC Support for VARBINARY column types. + + This handles the special case for null VARBINARY values, + which maps None values to the mx.ODBC.Manager.BinaryNull symbol. + """ + + def bind_processor(self, dialect): + if dialect.dbapi is None: + return None + + DBAPIBinary = dialect.dbapi.Binary + + def process(value): + if value is not None: + return DBAPIBinary(value) + else: + # should pull from mx.ODBC.Manager.BinaryNull + return dialect.dbapi.BinaryNull + + return process + + +class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc): + """ + The pyodbc execution context is useful for enabling + SELECT SCOPE_IDENTITY in cases where OUTPUT clause + does not work (tables with insert triggers). + """ + + # todo - investigate whether the pyodbc execution context + # is really only being used in cases where OUTPUT + # won't work. + + +class MSDialect_mxodbc(MxODBCConnector, MSDialect): + + # this is only needed if "native ODBC" mode is used, + # which is now disabled by default. + # statement_compiler = MSSQLStrictCompiler + supports_statement_cache = True + + execution_ctx_cls = MSExecutionContext_mxodbc + + # flag used by _MSNumeric_mxodbc + _need_decimal_fix = True + + colspecs = { + sqltypes.Numeric: _MSNumeric_mxodbc, + sqltypes.DateTime: _MSDateTime, + sqltypes.Date: _MSDate_mxodbc, + sqltypes.Time: _MSTime_mxodbc, + VARBINARY: _VARBINARY_mxodbc, + sqltypes.LargeBinary: _VARBINARY_mxodbc, + } + + def __init__(self, description_encoding=None, **params): + super(MSDialect_mxodbc, self).__init__(**params) + self.description_encoding = description_encoding + + +dialect = MSDialect_mxodbc diff --git a/lib/sqlalchemy/dialects/mssql/provision.py b/lib/sqlalchemy/dialects/mssql/provision.py new file mode 100644 index 0000000..56f3305 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/provision.py @@ -0,0 +1,116 @@ +from sqlalchemy import inspect +from sqlalchemy import Integer +from ... import create_engine +from ... import exc +from ...schema import Column +from ...schema import DropConstraint +from ...schema import ForeignKeyConstraint +from ...schema import MetaData +from ...schema import Table +from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_pre_tables +from ...testing.provision import drop_db +from ...testing.provision import get_temp_table_name +from ...testing.provision import log +from ...testing.provision import run_reap_dbs +from ...testing.provision import temp_table_keyword_args + + +@create_db.for_db("mssql") +def _mssql_create_db(cfg, eng, ident): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + conn.exec_driver_sql("create database %s" % ident) + conn.exec_driver_sql( + "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident + ) + conn.exec_driver_sql( + "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident + ) + conn.exec_driver_sql("use %s" % ident) + conn.exec_driver_sql("create schema test_schema") + conn.exec_driver_sql("create schema test_schema_2") + + +@drop_db.for_db("mssql") +def _mssql_drop_db(cfg, eng, ident): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + _mssql_drop_ignore(conn, ident) + + +def _mssql_drop_ignore(conn, ident): + try: + # typically when this happens, we can't KILL the session anyway, + # so let the cleanup process drop the DBs + # for row in conn.exec_driver_sql( + # "select session_id from sys.dm_exec_sessions " + # "where database_id=db_id('%s')" % ident): + # log.info("killing SQL server session %s", row['session_id']) + # conn.exec_driver_sql("kill %s" % row['session_id']) + conn.exec_driver_sql("drop database %s" % ident) + log.info("Reaped db: %s", ident) + return True + except exc.DatabaseError as err: + log.warning("couldn't drop db: %s", err) + return False + + +@run_reap_dbs.for_db("mssql") +def _reap_mssql_dbs(url, idents): + log.info("db reaper connecting to %r", url) + eng = create_engine(url) + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + + log.info("identifiers in file: %s", ", ".join(idents)) + + to_reap = conn.exec_driver_sql( + "select d.name from sys.databases as d where name " + "like 'TEST_%' and not exists (select session_id " + "from sys.dm_exec_sessions " + "where database_id=d.database_id)" + ) + all_names = {dbname.lower() for (dbname,) in to_reap} + to_drop = set() + for name in all_names: + if name in idents: + to_drop.add(name) + + dropped = total = 0 + for total, dbname in enumerate(to_drop, 1): + if _mssql_drop_ignore(conn, dbname): + dropped += 1 + log.info( + "Dropped %d out of %d stale databases detected", dropped, total + ) + + +@temp_table_keyword_args.for_db("mssql") +def _mssql_temp_table_keyword_args(cfg, eng): + return {} + + +@get_temp_table_name.for_db("mssql") +def _mssql_get_temp_table_name(cfg, eng, base_name): + return "##" + base_name + + +@drop_all_schema_objects_pre_tables.for_db("mssql") +def drop_all_schema_objects_pre_tables(cfg, eng): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + inspector = inspect(conn) + for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2): + for tname in inspector.get_table_names(schema=schema): + tb = Table( + tname, + MetaData(), + Column("x", Integer), + Column("y", Integer), + schema=schema, + ) + for fk in inspect(conn).get_foreign_keys(tname, schema=schema): + conn.execute( + DropConstraint( + ForeignKeyConstraint( + [tb.c.x], [tb.c.y], name=fk["name"] + ) + ) + ) diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py new file mode 100644 index 0000000..84c5fed --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -0,0 +1,138 @@ +# mssql/pymssql.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: mssql+pymssql + :name: pymssql + :dbapi: pymssql + :connectstring: mssql+pymssql://<username>:<password>@<freetds_name>/?charset=utf8 + +pymssql is a Python module that provides a Python DBAPI interface around +`FreeTDS <https://www.freetds.org/>`_. + +.. note:: + + pymssql is currently not included in SQLAlchemy's continuous integration + (CI) testing. + +Modern versions of this driver worked very well with SQL Server and FreeTDS +from Linux and were highly recommended. However, pymssql is currently +unmaintained and has fallen behind the progress of the Microsoft ODBC driver in +its support for newer features of SQL Server. The latest official release of +pymssql at the time of this document is version 2.1.4 (August, 2018) and it +lacks support for: + +1. table-valued parameters (TVPs), +2. ``datetimeoffset`` columns using timezone-aware ``datetime`` objects + (values are sent and retrieved as strings), and +3. encrypted connections (e.g., to Azure SQL), when pymssql is installed from + the pre-built wheels. Support for encrypted connections requires building + pymssql from source, which can be a nuisance, especially under Windows. + +The above features are all supported by mssql+pyodbc when using Microsoft's +ODBC Driver for SQL Server (msodbcsql), which is now available for Windows, +(several flavors of) Linux, and macOS. + + +""" # noqa +import re + +from .base import MSDialect +from .base import MSIdentifierPreparer +from ... import processors +from ... import types as sqltypes +from ... import util + + +class _MSNumeric_pymssql(sqltypes.Numeric): + def result_processor(self, dialect, type_): + if not self.asdecimal: + return processors.to_float + else: + return sqltypes.Numeric.result_processor(self, dialect, type_) + + +class MSIdentifierPreparer_pymssql(MSIdentifierPreparer): + def __init__(self, dialect): + super(MSIdentifierPreparer_pymssql, self).__init__(dialect) + # pymssql has the very unusual behavior that it uses pyformat + # yet does not require that percent signs be doubled + self._double_percents = False + + +class MSDialect_pymssql(MSDialect): + supports_statement_cache = True + supports_native_decimal = True + driver = "pymssql" + + preparer = MSIdentifierPreparer_pymssql + + colspecs = util.update_copy( + MSDialect.colspecs, + {sqltypes.Numeric: _MSNumeric_pymssql, sqltypes.Float: sqltypes.Float}, + ) + + @classmethod + def dbapi(cls): + module = __import__("pymssql") + # pymmsql < 2.1.1 doesn't have a Binary method. we use string + client_ver = tuple(int(x) for x in module.__version__.split(".")) + if client_ver < (2, 1, 1): + # TODO: monkeypatching here is less than ideal + module.Binary = lambda x: x if hasattr(x, "decode") else str(x) + + if client_ver < (1,): + util.warn( + "The pymssql dialect expects at least " + "the 1.0 series of the pymssql DBAPI." + ) + return module + + def _get_server_version_info(self, connection): + vers = connection.exec_driver_sql("select @@version").scalar() + m = re.match(r"Microsoft .*? - (\d+)\.(\d+)\.(\d+)\.(\d+)", vers) + if m: + return tuple(int(x) for x in m.group(1, 2, 3, 4)) + else: + return None + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + opts.update(url.query) + port = opts.pop("port", None) + if port and "host" in opts: + opts["host"] = "%s:%s" % (opts["host"], port) + return [[], opts] + + def is_disconnect(self, e, connection, cursor): + for msg in ( + "Adaptive Server connection timed out", + "Net-Lib error during Connection reset by peer", + "message 20003", # connection timeout + "Error 10054", + "Not connected to any MS SQL server", + "Connection is closed", + "message 20006", # Write to the server failed + "message 20017", # Unexpected EOF from the server + "message 20047", # DBPROCESS is dead or not enabled + ): + if msg in str(e): + return True + else: + return False + + def set_isolation_level(self, connection, level): + if level == "AUTOCOMMIT": + connection.autocommit(True) + else: + connection.autocommit(False) + super(MSDialect_pymssql, self).set_isolation_level( + connection, level + ) + + +dialect = MSDialect_pymssql diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py new file mode 100644 index 0000000..edb76f2 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -0,0 +1,673 @@ +# mssql/pyodbc.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: mssql+pyodbc + :name: PyODBC + :dbapi: pyodbc + :connectstring: mssql+pyodbc://<username>:<password>@<dsnname> + :url: https://pypi.org/project/pyodbc/ + +Connecting to PyODBC +-------------------- + +The URL here is to be translated to PyODBC connection strings, as +detailed in `ConnectionStrings <https://code.google.com/p/pyodbc/wiki/ConnectionStrings>`_. + +DSN Connections +^^^^^^^^^^^^^^^ + +A DSN connection in ODBC means that a pre-existing ODBC datasource is +configured on the client machine. The application then specifies the name +of this datasource, which encompasses details such as the specific ODBC driver +in use as well as the network address of the database. Assuming a datasource +is configured on the client, a basic DSN-based connection looks like:: + + engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn") + +Which above, will pass the following connection string to PyODBC:: + + DSN=some_dsn;UID=scott;PWD=tiger + +If the username and password are omitted, the DSN form will also add +the ``Trusted_Connection=yes`` directive to the ODBC string. + +Hostname Connections +^^^^^^^^^^^^^^^^^^^^ + +Hostname-based connections are also supported by pyodbc. These are often +easier to use than a DSN and have the additional advantage that the specific +database name to connect towards may be specified locally in the URL, rather +than it being fixed as part of a datasource configuration. + +When using a hostname connection, the driver name must also be specified in the +query parameters of the URL. As these names usually have spaces in them, the +name must be URL encoded which means using plus signs for spaces:: + + engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server") + +Other keywords interpreted by the Pyodbc dialect to be passed to +``pyodbc.connect()`` in both the DSN and hostname cases include: +``odbc_autotranslate``, ``ansi``, ``unicode_results``, ``autocommit``, +``authentication``. +Note that in order for the dialect to recognize these keywords +(including the ``driver`` keyword above) they must be all lowercase. +Multiple additional keyword arguments must be separated by an +ampersand (``&``), not a semicolon:: + + engine = create_engine( + "mssql+pyodbc://scott:tiger@myhost:49242/databasename" + "?driver=ODBC+Driver+17+for+SQL+Server" + "&authentication=ActiveDirectoryIntegrated" + ) + +The equivalent URL can be constructed using :class:`_sa.engine.URL`:: + + from sqlalchemy.engine import URL + connection_url = URL.create( + "mssql+pyodbc", + username="scott", + password="tiger", + host="myhost", + port=49242, + database="databasename", + query={ + "driver": "ODBC Driver 17 for SQL Server", + "authentication": "ActiveDirectoryIntegrated", + }, + ) + + +Pass through exact Pyodbc string +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A PyODBC connection string can also be sent in pyodbc's format directly, as +specified in `the PyODBC documentation +<https://github.com/mkleehammer/pyodbc/wiki/Connecting-to-databases>`_, +using the parameter ``odbc_connect``. A :class:`_sa.engine.URL` object +can help make this easier:: + + from sqlalchemy.engine import URL + connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password" + connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string}) + + engine = create_engine(connection_url) + +.. _mssql_pyodbc_access_tokens: + +Connecting to databases with access tokens +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Some database servers are set up to only accept access tokens for login. For +example, SQL Server allows the use of Azure Active Directory tokens to connect +to databases. This requires creating a credential object using the +``azure-identity`` library. More information about the authentication step can be +found in `Microsoft's documentation +<https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=bash>`_. + +After getting an engine, the credentials need to be sent to ``pyodbc.connect`` +each time a connection is requested. One way to do this is to set up an event +listener on the engine that adds the credential token to the dialect's connect +call. This is discussed more generally in :ref:`engines_dynamic_tokens`. For +SQL Server in particular, this is passed as an ODBC connection attribute with +a data structure `described by Microsoft +<https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_. + +The following code snippet will create an engine that connects to an Azure SQL +database using Azure credentials:: + + import struct + from sqlalchemy import create_engine, event + from sqlalchemy.engine.url import URL + from azure import identity + + SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h + TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database + + connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server" + + engine = create_engine(connection_string) + + azure_credentials = identity.DefaultAzureCredential() + + @event.listens_for(engine, "do_connect") + def provide_token(dialect, conn_rec, cargs, cparams): + # remove the "Trusted_Connection" parameter that SQLAlchemy adds + cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "") + + # create token credential + raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le") + token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token) + + # apply it to keyword arguments + cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct} + +.. tip:: + + The ``Trusted_Connection`` token is currently added by the SQLAlchemy + pyodbc dialect when no username or password is present. This needs + to be removed per Microsoft's + `documentation for Azure access tokens + <https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_, + stating that a connection string when using an access token must not contain + ``UID``, ``PWD``, ``Authentication`` or ``Trusted_Connection`` parameters. + +.. _azure_synapse_ignore_no_transaction_on_rollback: + +Avoiding transaction-related exceptions on Azure Synapse Analytics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Azure Synapse Analytics has a significant difference in its transaction +handling compared to plain SQL Server; in some cases an error within a Synapse +transaction can cause it to be arbitrarily terminated on the server side, which +then causes the DBAPI ``.rollback()`` method (as well as ``.commit()``) to +fail. The issue prevents the usual DBAPI contract of allowing ``.rollback()`` +to pass silently if no transaction is present as the driver does not expect +this condition. The symptom of this failure is an exception with a message +resembling 'No corresponding transaction found. (111214)' when attempting to +emit a ``.rollback()`` after an operation had a failure of some kind. + +This specific case can be handled by passing ``ignore_no_transaction_on_rollback=True`` to +the SQL Server dialect via the :func:`_sa.create_engine` function as follows:: + + engine = create_engine(connection_url, ignore_no_transaction_on_rollback=True) + +Using the above parameter, the dialect will catch ``ProgrammingError`` +exceptions raised during ``connection.rollback()`` and emit a warning +if the error message contains code ``111214``, however will not raise +an exception. + +.. versionadded:: 1.4.40 Added the + ``ignore_no_transaction_on_rollback=True`` parameter. + +Enable autocommit for Azure SQL Data Warehouse (DW) connections +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Azure SQL Data Warehouse does not support transactions, +and that can cause problems with SQLAlchemy's "autobegin" (and implicit +commit/rollback) behavior. We can avoid these problems by enabling autocommit +at both the pyodbc and engine levels:: + + connection_url = sa.engine.URL.create( + "mssql+pyodbc", + username="scott", + password="tiger", + host="dw.azure.example.com", + database="mydb", + query={ + "driver": "ODBC Driver 17 for SQL Server", + "autocommit": "True", + }, + ) + + engine = create_engine(connection_url).execution_options( + isolation_level="AUTOCOMMIT" + ) + +Avoiding sending large string parameters as TEXT/NTEXT +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, for historical reasons, Microsoft's ODBC drivers for SQL Server +send long string parameters (greater than 4000 SBCS characters or 2000 Unicode +characters) as TEXT/NTEXT values. TEXT and NTEXT have been deprecated for many +years and are starting to cause compatibility issues with newer versions of +SQL_Server/Azure. For example, see `this +issue <https://github.com/mkleehammer/pyodbc/issues/835>`_. + +Starting with ODBC Driver 18 for SQL Server we can override the legacy +behavior and pass long strings as varchar(max)/nvarchar(max) using the +``LongAsMax=Yes`` connection string parameter:: + + connection_url = sa.engine.URL.create( + "mssql+pyodbc", + username="scott", + password="tiger", + host="mssqlserver.example.com", + database="mydb", + query={ + "driver": "ODBC Driver 18 for SQL Server", + "LongAsMax": "Yes", + }, + ) + + +Pyodbc Pooling / connection close behavior +------------------------------------------ + +PyODBC uses internal `pooling +<https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ by +default, which means connections will be longer lived than they are within +SQLAlchemy itself. As SQLAlchemy has its own pooling behavior, it is often +preferable to disable this behavior. This behavior can only be disabled +globally at the PyODBC module level, **before** any connections are made:: + + import pyodbc + + pyodbc.pooling = False + + # don't use the engine before pooling is set to False + engine = create_engine("mssql+pyodbc://user:pass@dsn") + +If this variable is left at its default value of ``True``, **the application +will continue to maintain active database connections**, even when the +SQLAlchemy engine itself fully discards a connection or if the engine is +disposed. + +.. seealso:: + + `pooling <https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ - + in the PyODBC documentation. + +Driver / Unicode Support +------------------------- + +PyODBC works best with Microsoft ODBC drivers, particularly in the area +of Unicode support on both Python 2 and Python 3. + +Using the FreeTDS ODBC drivers on Linux or OSX with PyODBC is **not** +recommended; there have been historically many Unicode-related issues +in this area, including before Microsoft offered ODBC drivers for Linux +and OSX. Now that Microsoft offers drivers for all platforms, for +PyODBC support these are recommended. FreeTDS remains relevant for +non-ODBC drivers such as pymssql where it works very well. + + +Rowcount Support +---------------- + +Pyodbc only has partial support for rowcount. See the notes at +:ref:`mssql_rowcount_versioning` for important notes when using ORM +versioning. + +.. _mssql_pyodbc_fastexecutemany: + +Fast Executemany Mode +--------------------- + +The Pyodbc driver has added support for a "fast executemany" mode of execution +which greatly reduces round trips for a DBAPI ``executemany()`` call when using +Microsoft ODBC drivers, for **limited size batches that fit in memory**. The +feature is enabled by setting the flag ``.fast_executemany`` on the DBAPI +cursor when an executemany call is to be used. The SQLAlchemy pyodbc SQL +Server dialect supports setting this flag automatically when the +``.fast_executemany`` flag is passed to +:func:`_sa.create_engine` ; note that the ODBC driver must be the Microsoft +driver in order to use this flag:: + + engine = create_engine( + "mssql+pyodbc://scott:tiger@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server", + fast_executemany=True) + +.. warning:: The pyodbc fast_executemany mode **buffers all rows in memory** and is + not compatible with very large batches of data. A future version of SQLAlchemy + may support this flag as a per-execution option instead. + +.. versionadded:: 1.3 + +.. seealso:: + + `fast executemany <https://github.com/mkleehammer/pyodbc/wiki/Features-beyond-the-DB-API#fast_executemany>`_ + - on github + +.. _mssql_pyodbc_setinputsizes: + +Setinputsizes Support +----------------------- + +The pyodbc ``cursor.setinputsizes()`` method can be used if necessary. To +enable this hook, pass ``use_setinputsizes=True`` to :func:`_sa.create_engine`:: + + engine = create_engine("mssql+pyodbc://...", use_setinputsizes=True) + +The behavior of the hook can then be customized, as may be necessary +particularly if fast_executemany is in use, via the +:meth:`.DialectEvents.do_setinputsizes` hook. See that method for usage +examples. + +.. versionchanged:: 1.4.1 The pyodbc dialects will not use setinputsizes + unless ``use_setinputsizes=True`` is passed. + +""" # noqa + + +import datetime +import decimal +import re +import struct + +from .base import BINARY +from .base import DATETIMEOFFSET +from .base import MSDialect +from .base import MSExecutionContext +from .base import VARBINARY +from ... import exc +from ... import types as sqltypes +from ... import util +from ...connectors.pyodbc import PyODBCConnector + + +class _ms_numeric_pyodbc(object): + + """Turns Decimals with adjusted() < 0 or > 7 into strings. + + The routines here are needed for older pyodbc versions + as well as current mxODBC versions. + + """ + + def bind_processor(self, dialect): + + super_process = super(_ms_numeric_pyodbc, self).bind_processor(dialect) + + if not dialect._need_decimal_fix: + return super_process + + def process(value): + if self.asdecimal and isinstance(value, decimal.Decimal): + adjusted = value.adjusted() + if adjusted < 0: + return self._small_dec_to_string(value) + elif adjusted > 7: + return self._large_dec_to_string(value) + + if super_process: + return super_process(value) + else: + return value + + return process + + # these routines needed for older versions of pyodbc. + # as of 2.1.8 this logic is integrated. + + def _small_dec_to_string(self, value): + return "%s0.%s%s" % ( + (value < 0 and "-" or ""), + "0" * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value.as_tuple()[1]]), + ) + + def _large_dec_to_string(self, value): + _int = value.as_tuple()[1] + if "E" in str(value): + result = "%s%s%s" % ( + (value < 0 and "-" or ""), + "".join([str(s) for s in _int]), + "0" * (value.adjusted() - (len(_int) - 1)), + ) + else: + if (len(_int) - 1) > value.adjusted(): + result = "%s%s.%s" % ( + (value < 0 and "-" or ""), + "".join([str(s) for s in _int][0 : value.adjusted() + 1]), + "".join([str(s) for s in _int][value.adjusted() + 1 :]), + ) + else: + result = "%s%s" % ( + (value < 0 and "-" or ""), + "".join([str(s) for s in _int][0 : value.adjusted() + 1]), + ) + return result + + +class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric): + pass + + +class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float): + pass + + +class _ms_binary_pyodbc(object): + """Wraps binary values in dialect-specific Binary wrapper. + If the value is null, return a pyodbc-specific BinaryNull + object to prevent pyODBC [and FreeTDS] from defaulting binary + NULL types to SQLWCHAR and causing implicit conversion errors. + """ + + def bind_processor(self, dialect): + if dialect.dbapi is None: + return None + + DBAPIBinary = dialect.dbapi.Binary + + def process(value): + if value is not None: + return DBAPIBinary(value) + else: + # pyodbc-specific + return dialect.dbapi.BinaryNull + + return process + + +class _ODBCDateTimeBindProcessor(object): + """Add bind processors to handle datetimeoffset behaviors""" + + has_tz = False + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, util.string_types): + # if a string was passed directly, allow it through + return value + elif not value.tzinfo or (not self.timezone and not self.has_tz): + # for DateTime(timezone=False) + return value + else: + # for DATETIMEOFFSET or DateTime(timezone=True) + # + # Convert to string format required by T-SQL + dto_string = value.strftime("%Y-%m-%d %H:%M:%S.%f %z") + # offset needs a colon, e.g., -0700 -> -07:00 + # "UTC offset in the form (+-)HHMM[SS[.ffffff]]" + # backend currently rejects seconds / fractional seconds + dto_string = re.sub( + r"([\+\-]\d{2})([\d\.]+)$", r"\1:\2", dto_string + ) + return dto_string + + return process + + +class _ODBCDateTime(_ODBCDateTimeBindProcessor, sqltypes.DateTime): + pass + + +class _ODBCDATETIMEOFFSET(_ODBCDateTimeBindProcessor, DATETIMEOFFSET): + has_tz = True + + +class _VARBINARY_pyodbc(_ms_binary_pyodbc, VARBINARY): + pass + + +class _BINARY_pyodbc(_ms_binary_pyodbc, BINARY): + pass + + +class MSExecutionContext_pyodbc(MSExecutionContext): + _embedded_scope_identity = False + + def pre_exec(self): + """where appropriate, issue "select scope_identity()" in the same + statement. + + Background on why "scope_identity()" is preferable to "@@identity": + https://msdn.microsoft.com/en-us/library/ms190315.aspx + + Background on why we attempt to embed "scope_identity()" into the same + statement as the INSERT: + https://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values? + + """ + + super(MSExecutionContext_pyodbc, self).pre_exec() + + # don't embed the scope_identity select into an + # "INSERT .. DEFAULT VALUES" + if ( + self._select_lastrowid + and self.dialect.use_scope_identity + and len(self.parameters[0]) + ): + self._embedded_scope_identity = True + + self.statement += "; select scope_identity()" + + def post_exec(self): + if self._embedded_scope_identity: + # Fetch the last inserted id from the manipulated statement + # We may have to skip over a number of result sets with + # no data (due to triggers, etc.) + while True: + try: + # fetchall() ensures the cursor is consumed + # without closing it (FreeTDS particularly) + row = self.cursor.fetchall()[0] + break + except self.dialect.dbapi.Error: + # no way around this - nextset() consumes the previous set + # so we need to just keep flipping + self.cursor.nextset() + + self._lastrowid = int(row[0]) + else: + super(MSExecutionContext_pyodbc, self).post_exec() + + +class MSDialect_pyodbc(PyODBCConnector, MSDialect): + supports_statement_cache = True + + # mssql still has problems with this on Linux + supports_sane_rowcount_returning = False + + execution_ctx_cls = MSExecutionContext_pyodbc + + colspecs = util.update_copy( + MSDialect.colspecs, + { + sqltypes.Numeric: _MSNumeric_pyodbc, + sqltypes.Float: _MSFloat_pyodbc, + BINARY: _BINARY_pyodbc, + # support DateTime(timezone=True) + sqltypes.DateTime: _ODBCDateTime, + DATETIMEOFFSET: _ODBCDATETIMEOFFSET, + # SQL Server dialect has a VARBINARY that is just to support + # "deprecate_large_types" w/ VARBINARY(max), but also we must + # handle the usual SQL standard VARBINARY + VARBINARY: _VARBINARY_pyodbc, + sqltypes.VARBINARY: _VARBINARY_pyodbc, + sqltypes.LargeBinary: _VARBINARY_pyodbc, + }, + ) + + def __init__( + self, description_encoding=None, fast_executemany=False, **params + ): + if "description_encoding" in params: + self.description_encoding = params.pop("description_encoding") + super(MSDialect_pyodbc, self).__init__(**params) + self.use_scope_identity = ( + self.use_scope_identity + and self.dbapi + and hasattr(self.dbapi.Cursor, "nextset") + ) + self._need_decimal_fix = self.dbapi and self._dbapi_version() < ( + 2, + 1, + 8, + ) + self.fast_executemany = fast_executemany + + def _get_server_version_info(self, connection): + try: + # "Version of the instance of SQL Server, in the form + # of 'major.minor.build.revision'" + raw = connection.exec_driver_sql( + "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)" + ).scalar() + except exc.DBAPIError: + # SQL Server docs indicate this function isn't present prior to + # 2008. Before we had the VARCHAR cast above, pyodbc would also + # fail on this query. + return super(MSDialect_pyodbc, self)._get_server_version_info( + connection, allow_chars=False + ) + else: + version = [] + r = re.compile(r"[.\-]") + for n in r.split(raw): + try: + version.append(int(n)) + except ValueError: + pass + return tuple(version) + + def on_connect(self): + super_ = super(MSDialect_pyodbc, self).on_connect() + + def on_connect(conn): + if super_ is not None: + super_(conn) + + self._setup_timestampoffset_type(conn) + + return on_connect + + def _setup_timestampoffset_type(self, connection): + # output converter function for datetimeoffset + def _handle_datetimeoffset(dto_value): + tup = struct.unpack("<6hI2h", dto_value) + return datetime.datetime( + tup[0], + tup[1], + tup[2], + tup[3], + tup[4], + tup[5], + tup[6] // 1000, + util.timezone( + datetime.timedelta(hours=tup[7], minutes=tup[8]) + ), + ) + + odbc_SQL_SS_TIMESTAMPOFFSET = -155 # as defined in SQLNCLI.h + connection.add_output_converter( + odbc_SQL_SS_TIMESTAMPOFFSET, _handle_datetimeoffset + ) + + def do_executemany(self, cursor, statement, parameters, context=None): + if self.fast_executemany: + cursor.fast_executemany = True + super(MSDialect_pyodbc, self).do_executemany( + cursor, statement, parameters, context=context + ) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error): + code = e.args[0] + if code in { + "08S01", + "01000", + "01002", + "08003", + "08007", + "08S02", + "08001", + "HYT00", + "HY010", + "10054", + }: + return True + return super(MSDialect_pyodbc, self).is_disconnect( + e, connection, cursor + ) + + +dialect = MSDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py new file mode 100644 index 0000000..04c83d1 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -0,0 +1,103 @@ +# mysql/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from . import base # noqa +from . import cymysql # noqa +from . import mariadbconnector # noqa +from . import mysqlconnector # noqa +from . import mysqldb # noqa +from . import oursql # noqa +from . import pymysql # noqa +from . import pyodbc # noqa +from .base import BIGINT +from .base import BINARY +from .base import BIT +from .base import BLOB +from .base import BOOLEAN +from .base import CHAR +from .base import DATE +from .base import DATETIME +from .base import DECIMAL +from .base import DOUBLE +from .base import ENUM +from .base import FLOAT +from .base import INTEGER +from .base import JSON +from .base import LONGBLOB +from .base import LONGTEXT +from .base import MEDIUMBLOB +from .base import MEDIUMINT +from .base import MEDIUMTEXT +from .base import NCHAR +from .base import NUMERIC +from .base import NVARCHAR +from .base import REAL +from .base import SET +from .base import SMALLINT +from .base import TEXT +from .base import TIME +from .base import TIMESTAMP +from .base import TINYBLOB +from .base import TINYINT +from .base import TINYTEXT +from .base import VARBINARY +from .base import VARCHAR +from .base import YEAR +from .dml import Insert +from .dml import insert +from .expression import match +from ...util import compat + +if compat.py3k: + from . import aiomysql # noqa + from . import asyncmy # noqa + +# default dialect +base.dialect = dialect = mysqldb.dialect + +__all__ = ( + "BIGINT", + "BINARY", + "BIT", + "BLOB", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "DOUBLE", + "ENUM", + "DECIMAL", + "FLOAT", + "INTEGER", + "INTEGER", + "JSON", + "LONGBLOB", + "LONGTEXT", + "MEDIUMBLOB", + "MEDIUMINT", + "MEDIUMTEXT", + "NCHAR", + "NVARCHAR", + "NUMERIC", + "SET", + "SMALLINT", + "REAL", + "TEXT", + "TIME", + "TIMESTAMP", + "TINYBLOB", + "TINYINT", + "TINYTEXT", + "VARBINARY", + "VARCHAR", + "YEAR", + "dialect", + "insert", + "Insert", + "match", +) diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py new file mode 100644 index 0000000..975467c --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -0,0 +1,317 @@ +# mysql/aiomysql.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS +# file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: mysql+aiomysql + :name: aiomysql + :dbapi: aiomysql + :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...] + :url: https://github.com/aio-libs/aiomysql + +.. warning:: The aiomysql dialect is not currently tested as part of + SQLAlchemy’s continuous integration. As of September, 2021 the driver + appears to be unmaintained and no longer functions for Python version 3.10, + and additionally depends on a significantly outdated version of PyMySQL. + Please refer to the :ref:`asyncmy` dialect for current MySQL/MariaDB asyncio + functionality. + +The aiomysql dialect is SQLAlchemy's second Python asyncio dialect. + +Using a special asyncio mediation layer, the aiomysql dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4") + + +""" # noqa + +from .pymysql import MySQLDialect_pymysql +from ... import pool +from ... import util +from ...engine import AdaptedConnection +from ...util.concurrency import asyncio +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_aiomysql_cursor: + server_side = False + __slots__ = ( + "_adapt_connection", + "_connection", + "await_", + "_cursor", + "_rows", + ) + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor() + + # see https://github.com/aio-libs/aiomysql/issues/543 + self._cursor = self.await_(cursor.__aenter__()) + self._rows = [] + + @property + def description(self): + return self._cursor.description + + @property + def rowcount(self): + return self._cursor.rowcount + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + @property + def lastrowid(self): + return self._cursor.lastrowid + + def close(self): + # note we aren't actually closing the cursor here, + # we are just letting GC do it. to allow this to be async + # we would need the Result to change how it does "Safe close cursor". + # MySQL "cursors" don't actually have state to be "closed" besides + # exhausting rows, which we already have done for sync cursor. + # another option would be to emulate aiosqlite dialect and assign + # cursor only if we are doing server side cursor operation. + self._rows[:] = [] + + def execute(self, operation, parameters=None): + return self.await_(self._execute_async(operation, parameters)) + + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._executemany_async(operation, seq_of_parameters) + ) + + async def _execute_async(self, operation, parameters): + async with self._adapt_connection._execute_mutex: + if parameters is None: + result = await self._cursor.execute(operation) + else: + result = await self._cursor.execute(operation, parameters) + + if not self.server_side: + # aiomysql has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. + self._rows = list(await self._cursor.fetchall()) + return result + + async def _executemany_async(self, operation, seq_of_parameters): + async with self._adapt_connection._execute_mutex: + return await self._cursor.executemany(operation, seq_of_parameters) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor): + __slots__ = () + server_side = True + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor( + adapt_connection.dbapi.aiomysql.SSCursor + ) + + self._cursor = self.await_(cursor.__aenter__()) + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_aiomysql_connection(AdaptedConnection): + await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_connection", "_execute_mutex") + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + self._execute_mutex = asyncio.Lock() + + def ping(self, reconnect): + return self.await_(self._connection.ping(reconnect)) + + def character_set_name(self): + return self._connection.character_set_name() + + def autocommit(self, value): + self.await_(self._connection.autocommit(value)) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_aiomysql_ss_cursor(self) + else: + return AsyncAdapt_aiomysql_cursor(self) + + def rollback(self): + self.await_(self._connection.rollback()) + + def commit(self): + self.await_(self._connection.commit()) + + def close(self): + # it's not awaitable. + self._connection.close() + + +class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_aiomysql_dbapi: + def __init__(self, aiomysql, pymysql): + self.aiomysql = aiomysql + self.pymysql = pymysql + self.paramstyle = "format" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name in ( + "Warning", + "Error", + "InterfaceError", + "DataError", + "DatabaseError", + "OperationalError", + "InterfaceError", + "IntegrityError", + "ProgrammingError", + "InternalError", + "NotSupportedError", + ): + setattr(self, name, getattr(self.aiomysql, name)) + + for name in ( + "NUMBER", + "STRING", + "DATETIME", + "BINARY", + "TIMESTAMP", + "Binary", + ): + setattr(self, name, getattr(self.pymysql, name)) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + + if util.asbool(async_fallback): + return AsyncAdaptFallback_aiomysql_connection( + self, + await_fallback(self.aiomysql.connect(*arg, **kw)), + ) + else: + return AsyncAdapt_aiomysql_connection( + self, + await_only(self.aiomysql.connect(*arg, **kw)), + ) + + +class MySQLDialect_aiomysql(MySQLDialect_pymysql): + driver = "aiomysql" + supports_statement_cache = True + + supports_server_side_cursors = True + _sscursor = AsyncAdapt_aiomysql_ss_cursor + + is_async = True + + @classmethod + def dbapi(cls): + return AsyncAdapt_aiomysql_dbapi( + __import__("aiomysql"), __import__("pymysql") + ) + + @classmethod + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def create_connect_args(self, url): + return super(MySQLDialect_aiomysql, self).create_connect_args( + url, _translate_args=dict(username="user", database="db") + ) + + def is_disconnect(self, e, connection, cursor): + if super(MySQLDialect_aiomysql, self).is_disconnect( + e, connection, cursor + ): + return True + else: + str_e = str(e).lower() + return "not connected" in str_e + + def _found_rows_client_flag(self): + from pymysql.constants import CLIENT + + return CLIENT.FOUND_ROWS + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = MySQLDialect_aiomysql diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py new file mode 100644 index 0000000..521918a --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -0,0 +1,328 @@ +# mysql/asyncmy.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS +# file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: mysql+asyncmy + :name: asyncmy + :dbapi: asyncmy + :connectstring: mysql+asyncmy://user:password@host:port/dbname[?key=value&key=value...] + :url: https://github.com/long2ice/asyncmy + +.. note:: The asyncmy dialect as of September, 2021 was added to provide + MySQL/MariaDB asyncio compatibility given that the :ref:`aiomysql` database + driver has become unmaintained, however asyncmy is itself very new. + +Using a special asyncio mediation layer, the asyncmy dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4") + + +""" # noqa + +from .pymysql import MySQLDialect_pymysql +from ... import pool +from ... import util +from ...engine import AdaptedConnection +from ...util.concurrency import asynccontextmanager +from ...util.concurrency import asyncio +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_asyncmy_cursor: + server_side = False + __slots__ = ( + "_adapt_connection", + "_connection", + "await_", + "_cursor", + "_rows", + ) + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor() + + self._cursor = self.await_(cursor.__aenter__()) + self._rows = [] + + @property + def description(self): + return self._cursor.description + + @property + def rowcount(self): + return self._cursor.rowcount + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + @property + def lastrowid(self): + return self._cursor.lastrowid + + def close(self): + # note we aren't actually closing the cursor here, + # we are just letting GC do it. to allow this to be async + # we would need the Result to change how it does "Safe close cursor". + # MySQL "cursors" don't actually have state to be "closed" besides + # exhausting rows, which we already have done for sync cursor. + # another option would be to emulate aiosqlite dialect and assign + # cursor only if we are doing server side cursor operation. + self._rows[:] = [] + + def execute(self, operation, parameters=None): + return self.await_(self._execute_async(operation, parameters)) + + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._executemany_async(operation, seq_of_parameters) + ) + + async def _execute_async(self, operation, parameters): + async with self._adapt_connection._mutex_and_adapt_errors(): + if parameters is None: + result = await self._cursor.execute(operation) + else: + result = await self._cursor.execute(operation, parameters) + + if not self.server_side: + # asyncmy has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. + self._rows = list(await self._cursor.fetchall()) + return result + + async def _executemany_async(self, operation, seq_of_parameters): + async with self._adapt_connection._mutex_and_adapt_errors(): + return await self._cursor.executemany(operation, seq_of_parameters) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor): + __slots__ = () + server_side = True + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor( + adapt_connection.dbapi.asyncmy.cursors.SSCursor + ) + + self._cursor = self.await_(cursor.__aenter__()) + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_asyncmy_connection(AdaptedConnection): + await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_connection", "_execute_mutex") + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + self._execute_mutex = asyncio.Lock() + + @asynccontextmanager + async def _mutex_and_adapt_errors(self): + async with self._execute_mutex: + try: + yield + except AttributeError: + raise self.dbapi.InternalError( + "network operation failed due to asyncmy attribute error" + ) + + def ping(self, reconnect): + assert not reconnect + return self.await_(self._do_ping()) + + async def _do_ping(self): + async with self._mutex_and_adapt_errors(): + return await self._connection.ping(False) + + def character_set_name(self): + return self._connection.character_set_name() + + def autocommit(self, value): + self.await_(self._connection.autocommit(value)) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_asyncmy_ss_cursor(self) + else: + return AsyncAdapt_asyncmy_cursor(self) + + def rollback(self): + self.await_(self._connection.rollback()) + + def commit(self): + self.await_(self._connection.commit()) + + def close(self): + # it's not awaitable. + self._connection.close() + + +class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +def _Binary(x): + """Return x as a binary type.""" + return bytes(x) + + +class AsyncAdapt_asyncmy_dbapi: + def __init__(self, asyncmy): + self.asyncmy = asyncmy + self.paramstyle = "format" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name in ( + "Warning", + "Error", + "InterfaceError", + "DataError", + "DatabaseError", + "OperationalError", + "InterfaceError", + "IntegrityError", + "ProgrammingError", + "InternalError", + "NotSupportedError", + ): + setattr(self, name, getattr(self.asyncmy.errors, name)) + + STRING = util.symbol("STRING") + NUMBER = util.symbol("NUMBER") + BINARY = util.symbol("BINARY") + DATETIME = util.symbol("DATETIME") + TIMESTAMP = util.symbol("TIMESTAMP") + Binary = staticmethod(_Binary) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + + if util.asbool(async_fallback): + return AsyncAdaptFallback_asyncmy_connection( + self, + await_fallback(self.asyncmy.connect(*arg, **kw)), + ) + else: + return AsyncAdapt_asyncmy_connection( + self, + await_only(self.asyncmy.connect(*arg, **kw)), + ) + + +class MySQLDialect_asyncmy(MySQLDialect_pymysql): + driver = "asyncmy" + supports_statement_cache = True + + supports_server_side_cursors = True + _sscursor = AsyncAdapt_asyncmy_ss_cursor + + is_async = True + + @classmethod + def dbapi(cls): + return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy")) + + @classmethod + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def create_connect_args(self, url): + return super(MySQLDialect_asyncmy, self).create_connect_args( + url, _translate_args=dict(username="user", database="db") + ) + + def is_disconnect(self, e, connection, cursor): + if super(MySQLDialect_asyncmy, self).is_disconnect( + e, connection, cursor + ): + return True + else: + str_e = str(e).lower() + return ( + "not connected" in str_e or "network operation failed" in str_e + ) + + def _found_rows_client_flag(self): + from asyncmy.constants import CLIENT + + return CLIENT.FOUND_ROWS + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = MySQLDialect_asyncmy diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py new file mode 100644 index 0000000..111c63b --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -0,0 +1,3306 @@ +# mysql/base.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" + +.. dialect:: mysql + :name: MySQL / MariaDB + :full_support: 5.6, 5.7, 8.0 / 10.4, 10.5 + :normal_support: 5.6+ / 10+ + :best_effort: 5.0.2+ / 5.0.2+ + +Supported Versions and Features +------------------------------- + +SQLAlchemy supports MySQL starting with version 5.0.2 through modern releases, +as well as all modern versions of MariaDB. See the official MySQL +documentation for detailed information about features supported in any given +server release. + +.. versionchanged:: 1.4 minimum MySQL version supported is now 5.0.2. + +MariaDB Support +~~~~~~~~~~~~~~~ + +The MariaDB variant of MySQL retains fundamental compatibility with MySQL's +protocols however the development of these two products continues to diverge. +Within the realm of SQLAlchemy, the two databases have a small number of +syntactical and behavioral differences that SQLAlchemy accommodates automatically. +To connect to a MariaDB database, no changes to the database URL are required:: + + + engine = create_engine("mysql+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4") + +Upon first connect, the SQLAlchemy dialect employs a +server version detection scheme that determines if the +backing database reports as MariaDB. Based on this flag, the dialect +can make different choices in those of areas where its behavior +must be different. + +.. _mysql_mariadb_only_mode: + +MariaDB-Only Mode +~~~~~~~~~~~~~~~~~ + +The dialect also supports an **optional** "MariaDB-only" mode of connection, which may be +useful for the case where an application makes use of MariaDB-specific features +and is not compatible with a MySQL database. To use this mode of operation, +replace the "mysql" token in the above URL with "mariadb":: + + engine = create_engine("mariadb+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4") + +The above engine, upon first connect, will raise an error if the server version +detection detects that the backing database is not MariaDB. + +When using an engine with ``"mariadb"`` as the dialect name, **all mysql-specific options +that include the name "mysql" in them are now named with "mariadb"**. This means +options like ``mysql_engine`` should be named ``mariadb_engine``, etc. Both +"mysql" and "mariadb" options can be used simultaneously for applications that +use URLs with both "mysql" and "mariadb" dialects:: + + my_table = Table( + "mytable", + metadata, + Column("id", Integer, primary_key=True), + Column("textdata", String(50)), + mariadb_engine="InnoDB", + mysql_engine="InnoDB", + ) + + Index( + "textdata_ix", + my_table.c.textdata, + mysql_prefix="FULLTEXT", + mariadb_prefix="FULLTEXT", + ) + +Similar behavior will occur when the above structures are reflected, i.e. the +"mariadb" prefix will be present in the option names when the database URL +is based on the "mariadb" name. + +.. versionadded:: 1.4 Added "mariadb" dialect name supporting "MariaDB-only mode" + for the MySQL dialect. + +.. _mysql_connection_timeouts: + +Connection Timeouts and Disconnects +----------------------------------- + +MySQL / MariaDB feature an automatic connection close behavior, for connections that +have been idle for a fixed period of time, defaulting to eight hours. +To circumvent having this issue, use +the :paramref:`_sa.create_engine.pool_recycle` option which ensures that +a connection will be discarded and replaced with a new one if it has been +present in the pool for a fixed number of seconds:: + + engine = create_engine('mysql+mysqldb://...', pool_recycle=3600) + +For more comprehensive disconnect detection of pooled connections, including +accommodation of server restarts and network issues, a pre-ping approach may +be employed. See :ref:`pool_disconnects` for current approaches. + +.. seealso:: + + :ref:`pool_disconnects` - Background on several techniques for dealing + with timed out connections as well as database restarts. + +.. _mysql_storage_engines: + +CREATE TABLE arguments including Storage Engines +------------------------------------------------ + +Both MySQL's and MariaDB's CREATE TABLE syntax includes a wide array of special options, +including ``ENGINE``, ``CHARSET``, ``MAX_ROWS``, ``ROW_FORMAT``, +``INSERT_METHOD``, and many more. +To accommodate the rendering of these arguments, specify the form +``mysql_argument_name="value"``. For example, to specify a table with +``ENGINE`` of ``InnoDB``, ``CHARSET`` of ``utf8mb4``, and ``KEY_BLOCK_SIZE`` +of ``1024``:: + + Table('mytable', metadata, + Column('data', String(32)), + mysql_engine='InnoDB', + mysql_charset='utf8mb4', + mysql_key_block_size="1024" + ) + +When supporting :ref:`mysql_mariadb_only_mode` mode, similar keys against +the "mariadb" prefix must be included as well. The values can of course +vary independently so that different settings on MySQL vs. MariaDB may +be maintained:: + + # support both "mysql" and "mariadb-only" engine URLs + + Table('mytable', metadata, + Column('data', String(32)), + + mysql_engine='InnoDB', + mariadb_engine='InnoDB', + + mysql_charset='utf8mb4', + mariadb_charset='utf8', + + mysql_key_block_size="1024" + mariadb_key_block_size="1024" + + ) + +The MySQL / MariaDB dialects will normally transfer any keyword specified as +``mysql_keyword_name`` to be rendered as ``KEYWORD_NAME`` in the +``CREATE TABLE`` statement. A handful of these names will render with a space +instead of an underscore; to support this, the MySQL dialect has awareness of +these particular names, which include ``DATA DIRECTORY`` +(e.g. ``mysql_data_directory``), ``CHARACTER SET`` (e.g. +``mysql_character_set``) and ``INDEX DIRECTORY`` (e.g. +``mysql_index_directory``). + +The most common argument is ``mysql_engine``, which refers to the storage +engine for the table. Historically, MySQL server installations would default +to ``MyISAM`` for this value, although newer versions may be defaulting +to ``InnoDB``. The ``InnoDB`` engine is typically preferred for its support +of transactions and foreign keys. + +A :class:`_schema.Table` +that is created in a MySQL / MariaDB database with a storage engine +of ``MyISAM`` will be essentially non-transactional, meaning any +INSERT/UPDATE/DELETE statement referring to this table will be invoked as +autocommit. It also will have no support for foreign key constraints; while +the ``CREATE TABLE`` statement accepts foreign key options, when using the +``MyISAM`` storage engine these arguments are discarded. Reflecting such a +table will also produce no foreign key constraint information. + +For fully atomic transactions as well as support for foreign key +constraints, all participating ``CREATE TABLE`` statements must specify a +transactional engine, which in the vast majority of cases is ``InnoDB``. + + +Case Sensitivity and Table Reflection +------------------------------------- + +Both MySQL and MariaDB have inconsistent support for case-sensitive identifier +names, basing support on specific details of the underlying +operating system. However, it has been observed that no matter +what case sensitivity behavior is present, the names of tables in +foreign key declarations are *always* received from the database +as all-lower case, making it impossible to accurately reflect a +schema where inter-related tables use mixed-case identifier names. + +Therefore it is strongly advised that table names be declared as +all lower case both within SQLAlchemy as well as on the MySQL / MariaDB +database itself, especially if database reflection features are +to be used. + +.. _mysql_isolation_level: + +Transaction Isolation Level +--------------------------- + +All MySQL / MariaDB dialects support setting of transaction isolation level both via a +dialect-specific parameter :paramref:`_sa.create_engine.isolation_level` +accepted +by :func:`_sa.create_engine`, as well as the +:paramref:`.Connection.execution_options.isolation_level` argument as passed to +:meth:`_engine.Connection.execution_options`. +This feature works by issuing the +command ``SET SESSION TRANSACTION ISOLATION LEVEL <level>`` for each new +connection. For the special AUTOCOMMIT isolation level, DBAPI-specific +techniques are used. + +To set isolation level using :func:`_sa.create_engine`:: + + engine = create_engine( + "mysql://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED" + ) + +To set using per-connection execution options:: + + connection = engine.connect() + connection = connection.execution_options( + isolation_level="READ COMMITTED" + ) + +Valid values for ``isolation_level`` include: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +The special ``AUTOCOMMIT`` value makes use of the various "autocommit" +attributes provided by specific DBAPIs, and is currently supported by +MySQLdb, MySQL-Client, MySQL-Connector Python, and PyMySQL. Using it, +the database connection will return true for the value of +``SELECT @@autocommit;``. + +There are also more options for isolation level configurations, such as +"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply +different isolation level settings. See the discussion at +:ref:`dbapi_autocommit` for background. + +.. seealso:: + + :ref:`dbapi_autocommit` + +AUTO_INCREMENT Behavior +----------------------- + +When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT`` on +the first :class:`.Integer` primary key column which is not marked as a +foreign key:: + + >>> t = Table('mytable', metadata, + ... Column('mytable_id', Integer, primary_key=True) + ... ) + >>> t.create() + CREATE TABLE mytable ( + id INTEGER NOT NULL AUTO_INCREMENT, + PRIMARY KEY (id) + ) + +You can disable this behavior by passing ``False`` to the +:paramref:`_schema.Column.autoincrement` argument of :class:`_schema.Column`. +This flag +can also be used to enable auto-increment on a secondary column in a +multi-column key for some storage engines:: + + Table('mytable', metadata, + Column('gid', Integer, primary_key=True, autoincrement=False), + Column('id', Integer, primary_key=True) + ) + +.. _mysql_ss_cursors: + +Server Side Cursors +------------------- + +Server-side cursor support is available for the mysqlclient, PyMySQL, +mariadbconnector dialects and may also be available in others. This makes use +of either the "buffered=True/False" flag if available or by using a class such +as ``MySQLdb.cursors.SSCursor`` or ``pymysql.cursors.SSCursor`` internally. + + +Server side cursors are enabled on a per-statement basis by using the +:paramref:`.Connection.execution_options.stream_results` connection execution +option:: + + with engine.connect() as conn: + result = conn.execution_options(stream_results=True).execute(text("select * from table")) + +Note that some kinds of SQL statements may not be supported with +server side cursors; generally, only SQL statements that return rows should be +used with this option. + +.. deprecated:: 1.4 The dialect-level server_side_cursors flag is deprecated + and will be removed in a future release. Please use the + :paramref:`_engine.Connection.stream_results` execution option for + unbuffered cursor support. + +.. seealso:: + + :ref:`engine_stream_results` + +.. _mysql_unicode: + +Unicode +------- + +Charset Selection +~~~~~~~~~~~~~~~~~ + +Most MySQL / MariaDB DBAPIs offer the option to set the client character set for +a connection. This is typically delivered using the ``charset`` parameter +in the URL, such as:: + + e = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4") + +This charset is the **client character set** for the connection. Some +MySQL DBAPIs will default this to a value such as ``latin1``, and some +will make use of the ``default-character-set`` setting in the ``my.cnf`` +file as well. Documentation for the DBAPI in use should be consulted +for specific behavior. + +The encoding used for Unicode has traditionally been ``'utf8'``. However, for +MySQL versions 5.5.3 and MariaDB 5.5 on forward, a new MySQL-specific encoding +``'utf8mb4'`` has been introduced, and as of MySQL 8.0 a warning is emitted by +the server if plain ``utf8`` is specified within any server-side directives, +replaced with ``utf8mb3``. The rationale for this new encoding is due to the +fact that MySQL's legacy utf-8 encoding only supports codepoints up to three +bytes instead of four. Therefore, when communicating with a MySQL or MariaDB +database that includes codepoints more than three bytes in size, this new +charset is preferred, if supported by both the database as well as the client +DBAPI, as in:: + + e = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4") + +All modern DBAPIs should support the ``utf8mb4`` charset. + +In order to use ``utf8mb4`` encoding for a schema that was created with legacy +``utf8``, changes to the MySQL/MariaDB schema and/or server configuration may be +required. + +.. seealso:: + + `The utf8mb4 Character Set \ + <https://dev.mysql.com/doc/refman/5.5/en/charset-unicode-utf8mb4.html>`_ - \ + in the MySQL documentation + +.. _mysql_binary_introducer: + +Dealing with Binary Data Warnings and Unicode +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MySQL versions 5.6, 5.7 and later (not MariaDB at the time of this writing) now +emit a warning when attempting to pass binary data to the database, while a +character set encoding is also in place, when the binary data itself is not +valid for that encoding:: + + default.py:509: Warning: (1300, "Invalid utf8mb4 character string: + 'F9876A'") + cursor.execute(statement, parameters) + +This warning is due to the fact that the MySQL client library is attempting to +interpret the binary string as a unicode object even if a datatype such +as :class:`.LargeBinary` is in use. To resolve this, the SQL statement requires +a binary "character set introducer" be present before any non-NULL value +that renders like this:: + + INSERT INTO table (data) VALUES (_binary %s) + +These character set introducers are provided by the DBAPI driver, assuming the +use of mysqlclient or PyMySQL (both of which are recommended). Add the query +string parameter ``binary_prefix=true`` to the URL to repair this warning:: + + # mysqlclient + engine = create_engine( + "mysql+mysqldb://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true") + + # PyMySQL + engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true") + + +The ``binary_prefix`` flag may or may not be supported by other MySQL drivers. + +SQLAlchemy itself cannot render this ``_binary`` prefix reliably, as it does +not work with the NULL value, which is valid to be sent as a bound parameter. +As the MySQL driver renders parameters directly into the SQL string, it's the +most efficient place for this additional keyword to be passed. + +.. seealso:: + + `Character set introducers <https://dev.mysql.com/doc/refman/5.7/en/charset-introducer.html>`_ - on the MySQL website + + +ANSI Quoting Style +------------------ + +MySQL / MariaDB feature two varieties of identifier "quoting style", one using +backticks and the other using quotes, e.g. ```some_identifier``` vs. +``"some_identifier"``. All MySQL dialects detect which version +is in use by checking the value of :ref:`sql_mode<mysql_sql_mode>` when a connection is first +established with a particular :class:`_engine.Engine`. +This quoting style comes +into play when rendering table and column names as well as when reflecting +existing database structures. The detection is entirely automatic and +no special configuration is needed to use either quoting style. + + +.. _mysql_sql_mode: + +Changing the sql_mode +--------------------- + +MySQL supports operating in multiple +`Server SQL Modes <https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html>`_ for +both Servers and Clients. To change the ``sql_mode`` for a given application, a +developer can leverage SQLAlchemy's Events system. + +In the following example, the event system is used to set the ``sql_mode`` on +the ``first_connect`` and ``connect`` events:: + + from sqlalchemy import create_engine, event + + eng = create_engine("mysql://scott:tiger@localhost/test", echo='debug') + + # `insert=True` will ensure this is the very first listener to run + @event.listens_for(eng, "connect", insert=True) + def connect(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("SET sql_mode = 'STRICT_ALL_TABLES'") + + conn = eng.connect() + +In the example illustrated above, the "connect" event will invoke the "SET" +statement on the connection at the moment a particular DBAPI connection is +first created for a given Pool, before the connection is made available to the +connection pool. Additionally, because the function was registered with +``insert=True``, it will be prepended to the internal list of registered +functions. + + +MySQL / MariaDB SQL Extensions +------------------------------ + +Many of the MySQL / MariaDB SQL extensions are handled through SQLAlchemy's generic +function and operator support:: + + table.select(table.c.password==func.md5('plaintext')) + table.select(table.c.username.op('regexp')('^[a-d]')) + +And of course any valid SQL statement can be executed as a string as well. + +Some limited direct support for MySQL / MariaDB extensions to SQL is currently +available. + +* INSERT..ON DUPLICATE KEY UPDATE: See + :ref:`mysql_insert_on_duplicate_key_update` + +* SELECT pragma, use :meth:`_expression.Select.prefix_with` and + :meth:`_query.Query.prefix_with`:: + + select(...).prefix_with(['HIGH_PRIORITY', 'SQL_SMALL_RESULT']) + +* UPDATE with LIMIT:: + + update(..., mysql_limit=10, mariadb_limit=10) + +* optimizer hints, use :meth:`_expression.Select.prefix_with` and + :meth:`_query.Query.prefix_with`:: + + select(...).prefix_with("/*+ NO_RANGE_OPTIMIZATION(t4 PRIMARY) */") + +* index hints, use :meth:`_expression.Select.with_hint` and + :meth:`_query.Query.with_hint`:: + + select(...).with_hint(some_table, "USE INDEX xyz") + +* MATCH operator support:: + + from sqlalchemy.dialects.mysql import match + select(...).where(match(col1, col2, against="some expr").in_boolean_mode()) + + .. seealso:: + + :class:`_mysql.match` + +.. _mysql_insert_on_duplicate_key_update: + +INSERT...ON DUPLICATE KEY UPDATE (Upsert) +------------------------------------------ + +MySQL / MariaDB allow "upserts" (update or insert) +of rows into a table via the ``ON DUPLICATE KEY UPDATE`` clause of the +``INSERT`` statement. A candidate row will only be inserted if that row does +not match an existing primary or unique key in the table; otherwise, an UPDATE +will be performed. The statement allows for separate specification of the +values to INSERT versus the values for UPDATE. + +SQLAlchemy provides ``ON DUPLICATE KEY UPDATE`` support via the MySQL-specific +:func:`.mysql.insert()` function, which provides +the generative method :meth:`~.mysql.Insert.on_duplicate_key_update`: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.dialects.mysql import insert + + >>> insert_stmt = insert(my_table).values( + ... id='some_existing_id', + ... data='inserted value') + + >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( + ... data=insert_stmt.inserted.data, + ... status='U' + ... ) + >>> print(on_duplicate_key_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%s, %s) + ON DUPLICATE KEY UPDATE data = VALUES(data), status = %s + + +Unlike PostgreSQL's "ON CONFLICT" phrase, the "ON DUPLICATE KEY UPDATE" +phrase will always match on any primary key or unique key, and will always +perform an UPDATE if there's a match; there are no options for it to raise +an error or to skip performing an UPDATE. + +``ON DUPLICATE KEY UPDATE`` is used to perform an update of the already +existing row, using any combination of new values as well as values +from the proposed insertion. These values are normally specified using +keyword arguments passed to the +:meth:`_mysql.Insert.on_duplicate_key_update` +given column key values (usually the name of the column, unless it +specifies :paramref:`_schema.Column.key` +) as keys and literal or SQL expressions +as values: + +.. sourcecode:: pycon+sql + + >>> insert_stmt = insert(my_table).values( + ... id='some_existing_id', + ... data='inserted value') + + >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( + ... data="some data", + ... updated_at=func.current_timestamp(), + ... ) + + >>> print(on_duplicate_key_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%s, %s) + ON DUPLICATE KEY UPDATE data = %s, updated_at = CURRENT_TIMESTAMP + +In a manner similar to that of :meth:`.UpdateBase.values`, other parameter +forms are accepted, including a single dictionary: + +.. sourcecode:: pycon+sql + + >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( + ... {"data": "some data", "updated_at": func.current_timestamp()}, + ... ) + +as well as a list of 2-tuples, which will automatically provide +a parameter-ordered UPDATE statement in a manner similar to that described +at :ref:`tutorial_parameter_ordered_updates`. Unlike the :class:`_expression.Update` +object, +no special flag is needed to specify the intent since the argument form is +this context is unambiguous: + +.. sourcecode:: pycon+sql + + >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( + ... [ + ... ("data", "some data"), + ... ("updated_at", func.current_timestamp()), + ... ] + ... ) + + >>> print(on_duplicate_key_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%s, %s) + ON DUPLICATE KEY UPDATE data = %s, updated_at = CURRENT_TIMESTAMP + +.. versionchanged:: 1.3 support for parameter-ordered UPDATE clause within + MySQL ON DUPLICATE KEY UPDATE + +.. warning:: + + The :meth:`_mysql.Insert.on_duplicate_key_update` + method does **not** take into + account Python-side default UPDATE values or generation functions, e.g. + e.g. those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON DUPLICATE KEY style of UPDATE, + unless they are manually specified explicitly in the parameters. + + + +In order to refer to the proposed insertion row, the special alias +:attr:`_mysql.Insert.inserted` is available as an attribute on +the :class:`_mysql.Insert` object; this object is a +:class:`_expression.ColumnCollection` which contains all columns of the target +table: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh') + + >>> do_update_stmt = stmt.on_duplicate_key_update( + ... data="updated value", + ... author=stmt.inserted.author + ... ) + + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (id, data, author) VALUES (%s, %s, %s) + ON DUPLICATE KEY UPDATE data = %s, author = VALUES(author) + +When rendered, the "inserted" namespace will produce the expression +``VALUES(<columnname>)``. + +.. versionadded:: 1.2 Added support for MySQL ON DUPLICATE KEY UPDATE clause + + + +rowcount Support +---------------- + +SQLAlchemy standardizes the DBAPI ``cursor.rowcount`` attribute to be the +usual definition of "number of rows matched by an UPDATE or DELETE" statement. +This is in contradiction to the default setting on most MySQL DBAPI drivers, +which is "number of rows actually modified/deleted". For this reason, the +SQLAlchemy MySQL dialects always add the ``constants.CLIENT.FOUND_ROWS`` +flag, or whatever is equivalent for the target dialect, upon connection. +This setting is currently hardcoded. + +.. seealso:: + + :attr:`_engine.CursorResult.rowcount` + + +.. _mysql_indexes: + +MySQL / MariaDB- Specific Index Options +----------------------------------------- + +MySQL and MariaDB-specific extensions to the :class:`.Index` construct are available. + +Index Length +~~~~~~~~~~~~~ + +MySQL and MariaDB both provide an option to create index entries with a certain length, where +"length" refers to the number of characters or bytes in each value which will +become part of the index. SQLAlchemy provides this feature via the +``mysql_length`` and/or ``mariadb_length`` parameters:: + + Index('my_index', my_table.c.data, mysql_length=10, mariadb_length=10) + + Index('a_b_idx', my_table.c.a, my_table.c.b, mysql_length={'a': 4, + 'b': 9}) + + Index('a_b_idx', my_table.c.a, my_table.c.b, mariadb_length={'a': 4, + 'b': 9}) + +Prefix lengths are given in characters for nonbinary string types and in bytes +for binary string types. The value passed to the keyword argument *must* be +either an integer (and, thus, specify the same prefix length value for all +columns of the index) or a dict in which keys are column names and values are +prefix length values for corresponding columns. MySQL and MariaDB only allow a +length for a column of an index if it is for a CHAR, VARCHAR, TEXT, BINARY, +VARBINARY and BLOB. + +Index Prefixes +~~~~~~~~~~~~~~ + +MySQL storage engines permit you to specify an index prefix when creating +an index. SQLAlchemy provides this feature via the +``mysql_prefix`` parameter on :class:`.Index`:: + + Index('my_index', my_table.c.data, mysql_prefix='FULLTEXT') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX, so it *must* be a valid index prefix for your MySQL +storage engine. + +.. versionadded:: 1.1.5 + +.. seealso:: + + `CREATE INDEX <https://dev.mysql.com/doc/refman/5.0/en/create-index.html>`_ - MySQL documentation + +Index Types +~~~~~~~~~~~~~ + +Some MySQL storage engines permit you to specify an index type when creating +an index or primary key constraint. SQLAlchemy provides this feature via the +``mysql_using`` parameter on :class:`.Index`:: + + Index('my_index', my_table.c.data, mysql_using='hash', mariadb_using='hash') + +As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`:: + + PrimaryKeyConstraint("data", mysql_using='hash', mariadb_using='hash') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index +type for your MySQL storage engine. + +More information can be found at: + +https://dev.mysql.com/doc/refman/5.0/en/create-index.html + +https://dev.mysql.com/doc/refman/5.0/en/create-table.html + +Index Parsers +~~~~~~~~~~~~~ + +CREATE FULLTEXT INDEX in MySQL also supports a "WITH PARSER" option. This +is available using the keyword argument ``mysql_with_parser``:: + + Index( + 'my_index', my_table.c.data, + mysql_prefix='FULLTEXT', mysql_with_parser="ngram", + mariadb_prefix='FULLTEXT', mariadb_with_parser="ngram", + ) + +.. versionadded:: 1.3 + + +.. _mysql_foreign_keys: + +MySQL / MariaDB Foreign Keys +----------------------------- + +MySQL and MariaDB's behavior regarding foreign keys has some important caveats. + +Foreign Key Arguments to Avoid +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Neither MySQL nor MariaDB support the foreign key arguments "DEFERRABLE", "INITIALLY", +or "MATCH". Using the ``deferrable`` or ``initially`` keyword argument with +:class:`_schema.ForeignKeyConstraint` or :class:`_schema.ForeignKey` +will have the effect of +these keywords being rendered in a DDL expression, which will then raise an +error on MySQL or MariaDB. In order to use these keywords on a foreign key while having +them ignored on a MySQL / MariaDB backend, use a custom compile rule:: + + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.schema import ForeignKeyConstraint + + @compiles(ForeignKeyConstraint, "mysql", "mariadb") + def process(element, compiler, **kw): + element.deferrable = element.initially = None + return compiler.visit_foreign_key_constraint(element, **kw) + +The "MATCH" keyword is in fact more insidious, and is explicitly disallowed +by SQLAlchemy in conjunction with the MySQL or MariaDB backends. This argument is +silently ignored by MySQL / MariaDB, but in addition has the effect of ON UPDATE and ON +DELETE options also being ignored by the backend. Therefore MATCH should +never be used with the MySQL / MariaDB backends; as is the case with DEFERRABLE and +INITIALLY, custom compilation rules can be used to correct a +ForeignKeyConstraint at DDL definition time. + +Reflection of Foreign Key Constraints +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Not all MySQL / MariaDB storage engines support foreign keys. When using the +very common ``MyISAM`` MySQL storage engine, the information loaded by table +reflection will not include foreign keys. For these tables, you may supply a +:class:`~sqlalchemy.ForeignKeyConstraint` at reflection time:: + + Table('mytable', metadata, + ForeignKeyConstraint(['other_id'], ['othertable.other_id']), + autoload_with=engine + ) + +.. seealso:: + + :ref:`mysql_storage_engines` + +.. _mysql_unique_constraints: + +MySQL / MariaDB Unique Constraints and Reflection +---------------------------------------------------- + +SQLAlchemy supports both the :class:`.Index` construct with the +flag ``unique=True``, indicating a UNIQUE index, as well as the +:class:`.UniqueConstraint` construct, representing a UNIQUE constraint. +Both objects/syntaxes are supported by MySQL / MariaDB when emitting DDL to create +these constraints. However, MySQL / MariaDB does not have a unique constraint +construct that is separate from a unique index; that is, the "UNIQUE" +constraint on MySQL / MariaDB is equivalent to creating a "UNIQUE INDEX". + +When reflecting these constructs, the +:meth:`_reflection.Inspector.get_indexes` +and the :meth:`_reflection.Inspector.get_unique_constraints` +methods will **both** +return an entry for a UNIQUE index in MySQL / MariaDB. However, when performing +full table reflection using ``Table(..., autoload_with=engine)``, +the :class:`.UniqueConstraint` construct is +**not** part of the fully reflected :class:`_schema.Table` construct under any +circumstances; this construct is always represented by a :class:`.Index` +with the ``unique=True`` setting present in the :attr:`_schema.Table.indexes` +collection. + + +TIMESTAMP / DATETIME issues +--------------------------- + +.. _mysql_timestamp_onupdate: + +Rendering ON UPDATE CURRENT TIMESTAMP for MySQL / MariaDB's explicit_defaults_for_timestamp +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MySQL / MariaDB have historically expanded the DDL for the :class:`_types.TIMESTAMP` +datatype into the phrase "TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE +CURRENT_TIMESTAMP", which includes non-standard SQL that automatically updates +the column with the current timestamp when an UPDATE occurs, eliminating the +usual need to use a trigger in such a case where server-side update changes are +desired. + +MySQL 5.6 introduced a new flag `explicit_defaults_for_timestamp +<https://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html +#sysvar_explicit_defaults_for_timestamp>`_ which disables the above behavior, +and in MySQL 8 this flag defaults to true, meaning in order to get a MySQL +"on update timestamp" without changing this flag, the above DDL must be +rendered explicitly. Additionally, the same DDL is valid for use of the +``DATETIME`` datatype as well. + +SQLAlchemy's MySQL dialect does not yet have an option to generate +MySQL's "ON UPDATE CURRENT_TIMESTAMP" clause, noting that this is not a general +purpose "ON UPDATE" as there is no such syntax in standard SQL. SQLAlchemy's +:paramref:`_schema.Column.server_onupdate` parameter is currently not related +to this special MySQL behavior. + +To generate this DDL, make use of the :paramref:`_schema.Column.server_default` +parameter and pass a textual clause that also includes the ON UPDATE clause:: + + from sqlalchemy import Table, MetaData, Column, Integer, String, TIMESTAMP + from sqlalchemy import text + + metadata = MetaData() + + mytable = Table( + "mytable", + metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column( + 'last_updated', + TIMESTAMP, + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ) + ) + +The same instructions apply to use of the :class:`_types.DateTime` and +:class:`_types.DATETIME` datatypes:: + + from sqlalchemy import DateTime + + mytable = Table( + "mytable", + metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column( + 'last_updated', + DateTime, + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ) + ) + + +Even though the :paramref:`_schema.Column.server_onupdate` feature does not +generate this DDL, it still may be desirable to signal to the ORM that this +updated value should be fetched. This syntax looks like the following:: + + from sqlalchemy.schema import FetchedValue + + class MyClass(Base): + __tablename__ = 'mytable' + + id = Column(Integer, primary_key=True) + data = Column(String(50)) + last_updated = Column( + TIMESTAMP, + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"), + server_onupdate=FetchedValue() + ) + + +.. _mysql_timestamp_null: + +TIMESTAMP Columns and NULL +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MySQL historically enforces that a column which specifies the +TIMESTAMP datatype implicitly includes a default value of +CURRENT_TIMESTAMP, even though this is not stated, and additionally +sets the column as NOT NULL, the opposite behavior vs. that of all +other datatypes:: + + mysql> CREATE TABLE ts_test ( + -> a INTEGER, + -> b INTEGER NOT NULL, + -> c TIMESTAMP, + -> d TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + -> e TIMESTAMP NULL); + Query OK, 0 rows affected (0.03 sec) + + mysql> SHOW CREATE TABLE ts_test; + +---------+----------------------------------------------------- + | Table | Create Table + +---------+----------------------------------------------------- + | ts_test | CREATE TABLE `ts_test` ( + `a` int(11) DEFAULT NULL, + `b` int(11) NOT NULL, + `c` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + `d` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + `e` timestamp NULL DEFAULT NULL + ) ENGINE=MyISAM DEFAULT CHARSET=latin1 + +Above, we see that an INTEGER column defaults to NULL, unless it is specified +with NOT NULL. But when the column is of type TIMESTAMP, an implicit +default of CURRENT_TIMESTAMP is generated which also coerces the column +to be a NOT NULL, even though we did not specify it as such. + +This behavior of MySQL can be changed on the MySQL side using the +`explicit_defaults_for_timestamp +<https://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html +#sysvar_explicit_defaults_for_timestamp>`_ configuration flag introduced in +MySQL 5.6. With this server setting enabled, TIMESTAMP columns behave like +any other datatype on the MySQL side with regards to defaults and nullability. + +However, to accommodate the vast majority of MySQL databases that do not +specify this new flag, SQLAlchemy emits the "NULL" specifier explicitly with +any TIMESTAMP column that does not specify ``nullable=False``. In order to +accommodate newer databases that specify ``explicit_defaults_for_timestamp``, +SQLAlchemy also emits NOT NULL for TIMESTAMP columns that do specify +``nullable=False``. The following example illustrates:: + + from sqlalchemy import MetaData, Integer, Table, Column, text + from sqlalchemy.dialects.mysql import TIMESTAMP + + m = MetaData() + t = Table('ts_test', m, + Column('a', Integer), + Column('b', Integer, nullable=False), + Column('c', TIMESTAMP), + Column('d', TIMESTAMP, nullable=False) + ) + + + from sqlalchemy import create_engine + e = create_engine("mysql://scott:tiger@localhost/test", echo=True) + m.create_all(e) + +output:: + + CREATE TABLE ts_test ( + a INTEGER, + b INTEGER NOT NULL, + c TIMESTAMP NULL, + d TIMESTAMP NOT NULL + ) + +.. versionchanged:: 1.0.0 - SQLAlchemy now renders NULL or NOT NULL in all + cases for TIMESTAMP columns, to accommodate + ``explicit_defaults_for_timestamp``. Prior to this version, it will + not render "NOT NULL" for a TIMESTAMP column that is ``nullable=False``. + +""" # noqa + +from array import array as _array +from collections import defaultdict +from itertools import compress +import re + +from sqlalchemy import literal_column +from sqlalchemy import text +from sqlalchemy.sql import visitors +from . import reflection as _reflection +from .enumerated import ENUM +from .enumerated import SET +from .json import JSON +from .json import JSONIndexType +from .json import JSONPathType +from .reserved_words import RESERVED_WORDS_MARIADB +from .reserved_words import RESERVED_WORDS_MYSQL +from .types import _FloatType +from .types import _IntegerType +from .types import _MatchType +from .types import _NumericType +from .types import _StringType +from .types import BIGINT +from .types import BIT +from .types import CHAR +from .types import DATETIME +from .types import DECIMAL +from .types import DOUBLE +from .types import FLOAT +from .types import INTEGER +from .types import LONGBLOB +from .types import LONGTEXT +from .types import MEDIUMBLOB +from .types import MEDIUMINT +from .types import MEDIUMTEXT +from .types import NCHAR +from .types import NUMERIC +from .types import NVARCHAR +from .types import REAL +from .types import SMALLINT +from .types import TEXT +from .types import TIME +from .types import TIMESTAMP +from .types import TINYBLOB +from .types import TINYINT +from .types import TINYTEXT +from .types import VARCHAR +from .types import YEAR +from ... import exc +from ... import log +from ... import schema as sa_schema +from ... import sql +from ... import types as sqltypes +from ... import util +from ...engine import default +from ...engine import reflection +from ...sql import coercions +from ...sql import compiler +from ...sql import elements +from ...sql import functions +from ...sql import operators +from ...sql import roles +from ...sql import util as sql_util +from ...sql.sqltypes import Unicode +from ...types import BINARY +from ...types import BLOB +from ...types import BOOLEAN +from ...types import DATE +from ...types import VARBINARY +from ...util import topological + +AUTOCOMMIT_RE = re.compile( + r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)", + re.I | re.UNICODE, +) +SET_RE = re.compile( + r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE +) + + +# old names +MSTime = TIME +MSSet = SET +MSEnum = ENUM +MSLongBlob = LONGBLOB +MSMediumBlob = MEDIUMBLOB +MSTinyBlob = TINYBLOB +MSBlob = BLOB +MSBinary = BINARY +MSVarBinary = VARBINARY +MSNChar = NCHAR +MSNVarChar = NVARCHAR +MSChar = CHAR +MSString = VARCHAR +MSLongText = LONGTEXT +MSMediumText = MEDIUMTEXT +MSTinyText = TINYTEXT +MSText = TEXT +MSYear = YEAR +MSTimeStamp = TIMESTAMP +MSBit = BIT +MSSmallInteger = SMALLINT +MSTinyInteger = TINYINT +MSMediumInteger = MEDIUMINT +MSBigInteger = BIGINT +MSNumeric = NUMERIC +MSDecimal = DECIMAL +MSDouble = DOUBLE +MSReal = REAL +MSFloat = FLOAT +MSInteger = INTEGER + +colspecs = { + _IntegerType: _IntegerType, + _NumericType: _NumericType, + _FloatType: _FloatType, + sqltypes.Numeric: NUMERIC, + sqltypes.Float: FLOAT, + sqltypes.Time: TIME, + sqltypes.Enum: ENUM, + sqltypes.MatchType: _MatchType, + sqltypes.JSON: JSON, + sqltypes.JSON.JSONIndexType: JSONIndexType, + sqltypes.JSON.JSONPathType: JSONPathType, +} + +# Everything 3.23 through 5.1 excepting OpenGIS types. +ischema_names = { + "bigint": BIGINT, + "binary": BINARY, + "bit": BIT, + "blob": BLOB, + "boolean": BOOLEAN, + "char": CHAR, + "date": DATE, + "datetime": DATETIME, + "decimal": DECIMAL, + "double": DOUBLE, + "enum": ENUM, + "fixed": DECIMAL, + "float": FLOAT, + "int": INTEGER, + "integer": INTEGER, + "json": JSON, + "longblob": LONGBLOB, + "longtext": LONGTEXT, + "mediumblob": MEDIUMBLOB, + "mediumint": MEDIUMINT, + "mediumtext": MEDIUMTEXT, + "nchar": NCHAR, + "nvarchar": NVARCHAR, + "numeric": NUMERIC, + "set": SET, + "smallint": SMALLINT, + "text": TEXT, + "time": TIME, + "timestamp": TIMESTAMP, + "tinyblob": TINYBLOB, + "tinyint": TINYINT, + "tinytext": TINYTEXT, + "varbinary": VARBINARY, + "varchar": VARCHAR, + "year": YEAR, +} + + +class MySQLExecutionContext(default.DefaultExecutionContext): + def should_autocommit_text(self, statement): + return AUTOCOMMIT_RE.match(statement) + + def create_server_side_cursor(self): + if self.dialect.supports_server_side_cursors: + return self._dbapi_connection.cursor(self.dialect._sscursor) + else: + raise NotImplementedError() + + def fire_sequence(self, seq, type_): + return self._execute_scalar( + ( + "select nextval(%s)" + % self.identifier_preparer.format_sequence(seq) + ), + type_, + ) + + +class MySQLCompiler(compiler.SQLCompiler): + + render_table_with_column_in_update_from = True + """Overridden from base SQLCompiler value""" + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update({"milliseconds": "millisecond"}) + + def default_from(self): + """Called when a ``SELECT`` statement has no froms, + and no ``FROM`` clause is to be appended. + + """ + if self.stack: + stmt = self.stack[-1]["selectable"] + if stmt._where_criteria: + return " FROM DUAL" + + return "" + + def visit_random_func(self, fn, **kw): + return "rand%s" % self.function_argspec(fn) + + def visit_sequence(self, seq, **kw): + return "nextval(%s)" % self.preparer.format_sequence(seq) + + def visit_sysdate_func(self, fn, **kw): + return "SYSDATE()" + + def _render_json_extract_from_binary(self, binary, operator, **kw): + # note we are intentionally calling upon the process() calls in the + # order in which they appear in the SQL String as this is used + # by positional parameter rendering + + if binary.type._type_affinity is sqltypes.JSON: + return "JSON_EXTRACT(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + # for non-JSON, MySQL doesn't handle JSON null at all so it has to + # be explicit + case_expression = "CASE JSON_EXTRACT(%s, %s) WHEN 'null' THEN NULL" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + if binary.type._type_affinity is sqltypes.Integer: + type_expression = ( + "ELSE CAST(JSON_EXTRACT(%s, %s) AS SIGNED INTEGER)" + % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ) + elif binary.type._type_affinity is sqltypes.Numeric: + if ( + binary.type.scale is not None + and binary.type.precision is not None + ): + # using DECIMAL here because MySQL does not recognize NUMERIC + type_expression = ( + "ELSE CAST(JSON_EXTRACT(%s, %s) AS DECIMAL(%s, %s))" + % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + binary.type.precision, + binary.type.scale, + ) + ) + else: + # FLOAT / REAL not added in MySQL til 8.0.17 + type_expression = ( + "ELSE JSON_EXTRACT(%s, %s)+0.0000000000000000000000" + % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ) + elif binary.type._type_affinity is sqltypes.Boolean: + # the NULL handling is particularly weird with boolean, so + # explicitly return true/false constants + type_expression = "WHEN true THEN true ELSE false" + elif binary.type._type_affinity is sqltypes.String: + # (gord): this fails with a JSON value that's a four byte unicode + # string. SQLite has the same problem at the moment + # (zzzeek): I'm not really sure. let's take a look at a test case + # that hits each backend and maybe make a requires rule for it? + type_expression = "ELSE JSON_UNQUOTE(JSON_EXTRACT(%s, %s))" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + else: + # other affinity....this is not expected right now + type_expression = "ELSE JSON_EXTRACT(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + return case_expression + " " + type_expression + " END" + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_on_duplicate_key_update(self, on_duplicate, **kw): + statement = self.current_executable + + if on_duplicate._parameter_ordering: + parameter_ordering = [ + coercions.expect(roles.DMLColumnRole, key) + for key in on_duplicate._parameter_ordering + ] + ordered_keys = set(parameter_ordering) + cols = [ + statement.table.c[key] + for key in parameter_ordering + if key in statement.table.c + ] + [c for c in statement.table.c if c.key not in ordered_keys] + else: + cols = statement.table.c + + clauses = [] + # traverses through all table columns to preserve table column order + for column in (col for col in cols if col.key in on_duplicate.update): + + val = on_duplicate.update[column.key] + + if coercions._is_literal(val): + val = elements.BindParameter(None, val, type_=column.type) + value_text = self.process(val.self_group(), use_schema=False) + else: + + def replace(obj): + if ( + isinstance(obj, elements.BindParameter) + and obj.type._isnull + ): + obj = obj._clone() + obj.type = column.type + return obj + elif ( + isinstance(obj, elements.ColumnClause) + and obj.table is on_duplicate.inserted_alias + ): + obj = literal_column( + "VALUES(" + self.preparer.quote(obj.name) + ")" + ) + return obj + else: + # element is not replaced + return None + + val = visitors.replacement_traverse(val, {}, replace) + value_text = self.process(val.self_group(), use_schema=False) + + name_text = self.preparer.quote(column.name) + clauses.append("%s = %s" % (name_text, value_text)) + + non_matching = set(on_duplicate.update) - set(c.key for c in cols) + if non_matching: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( + self.statement.table.name, + (", ".join("'%s'" % c for c in non_matching)), + ) + ) + + return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses) + + def visit_concat_op_binary(self, binary, operator, **kw): + return "concat(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + _match_valid_flag_combinations = frozenset( + ( + # (boolean_mode, natural_language, query_expansion) + (False, False, False), + (True, False, False), + (False, True, False), + (False, False, True), + (False, True, True), + ) + ) + + _match_flag_expressions = ( + "IN BOOLEAN MODE", + "IN NATURAL LANGUAGE MODE", + "WITH QUERY EXPANSION", + ) + + def visit_mysql_match(self, element, **kw): + return self.visit_match_op_binary(element, element.operator, **kw) + + def visit_match_op_binary(self, binary, operator, **kw): + """ + Note that `mysql_boolean_mode` is enabled by default because of + backward compatibility + """ + + modifiers = binary.modifiers + + boolean_mode = modifiers.get("mysql_boolean_mode", True) + natural_language = modifiers.get("mysql_natural_language", False) + query_expansion = modifiers.get("mysql_query_expansion", False) + + flag_combination = (boolean_mode, natural_language, query_expansion) + + if flag_combination not in self._match_valid_flag_combinations: + flags = ( + "in_boolean_mode=%s" % boolean_mode, + "in_natural_language_mode=%s" % natural_language, + "with_query_expansion=%s" % query_expansion, + ) + + flags = ", ".join(flags) + + raise exc.CompileError("Invalid MySQL match flags: %s" % flags) + + match_clause = binary.left + match_clause = self.process(match_clause, **kw) + against_clause = self.process(binary.right, **kw) + + if any(flag_combination): + flag_expressions = compress( + self._match_flag_expressions, + flag_combination, + ) + + against_clause = [against_clause] + against_clause.extend(flag_expressions) + + against_clause = " ".join(against_clause) + + return "MATCH (%s) AGAINST (%s)" % (match_clause, against_clause) + + def get_from_hint_text(self, table, text): + return text + + def visit_typeclause(self, typeclause, type_=None, **kw): + if type_ is None: + type_ = typeclause.type.dialect_impl(self.dialect) + if isinstance(type_, sqltypes.TypeDecorator): + return self.visit_typeclause(typeclause, type_.impl, **kw) + elif isinstance(type_, sqltypes.Integer): + if getattr(type_, "unsigned", False): + return "UNSIGNED INTEGER" + else: + return "SIGNED INTEGER" + elif isinstance(type_, sqltypes.TIMESTAMP): + return "DATETIME" + elif isinstance( + type_, + ( + sqltypes.DECIMAL, + sqltypes.DateTime, + sqltypes.Date, + sqltypes.Time, + ), + ): + return self.dialect.type_compiler.process(type_) + elif isinstance(type_, sqltypes.String) and not isinstance( + type_, (ENUM, SET) + ): + adapted = CHAR._adapt_string_for_cast(type_) + return self.dialect.type_compiler.process(adapted) + elif isinstance(type_, sqltypes._Binary): + return "BINARY" + elif isinstance(type_, sqltypes.JSON): + return "JSON" + elif isinstance(type_, sqltypes.NUMERIC): + return self.dialect.type_compiler.process(type_).replace( + "NUMERIC", "DECIMAL" + ) + elif ( + isinstance(type_, sqltypes.Float) + and self.dialect._support_float_cast + ): + return self.dialect.type_compiler.process(type_) + else: + return None + + def visit_cast(self, cast, **kw): + type_ = self.process(cast.typeclause) + if type_ is None: + util.warn( + "Datatype %s does not support CAST on MySQL/MariaDb; " + "the CAST will be skipped." + % self.dialect.type_compiler.process(cast.typeclause.type) + ) + return self.process(cast.clause.self_group(), **kw) + + return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_) + + def render_literal_value(self, value, type_): + value = super(MySQLCompiler, self).render_literal_value(value, type_) + if self.dialect._backslash_escapes: + value = value.replace("\\", "\\\\") + return value + + # override native_boolean=False behavior here, as + # MySQL still supports native boolean + def visit_true(self, element, **kw): + return "true" + + def visit_false(self, element, **kw): + return "false" + + def get_select_precolumns(self, select, **kw): + """Add special MySQL keywords in place of DISTINCT. + + .. deprecated 1.4:: this usage is deprecated. + :meth:`_expression.Select.prefix_with` should be used for special + keywords at the start of a SELECT. + + """ + if isinstance(select._distinct, util.string_types): + util.warn_deprecated( + "Sending string values for 'distinct' is deprecated in the " + "MySQL dialect and will be removed in a future release. " + "Please use :meth:`.Select.prefix_with` for special keywords " + "at the start of a SELECT statement", + version="1.4", + ) + return select._distinct.upper() + " " + + return super(MySQLCompiler, self).get_select_precolumns(select, **kw) + + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.add((join.left, join.right)) + + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + else: + join_type = " INNER JOIN " + + return "".join( + ( + self.process( + join.left, asfrom=True, from_linter=from_linter, **kwargs + ), + join_type, + self.process( + join.right, asfrom=True, from_linter=from_linter, **kwargs + ), + " ON ", + self.process(join.onclause, from_linter=from_linter, **kwargs), + ) + ) + + def for_update_clause(self, select, **kw): + if select._for_update_arg.read: + tmp = " LOCK IN SHARE MODE" + else: + tmp = " FOR UPDATE" + + if select._for_update_arg.of and self.dialect.supports_for_update_of: + + tables = util.OrderedSet() + for c in select._for_update_arg.of: + tables.update(sql_util.surface_selectables_only(c)) + + tmp += " OF " + ", ".join( + self.process(table, ashint=True, use_schema=False, **kw) + for table in tables + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + + if select._for_update_arg.skip_locked: + tmp += " SKIP LOCKED" + + return tmp + + def limit_clause(self, select, **kw): + # MySQL supports: + # LIMIT <limit> + # LIMIT <offset>, <limit> + # and in server versions > 3.3: + # LIMIT <limit> OFFSET <offset> + # The latter is more readable for offsets but we're stuck with the + # former until we can refine dialects by server revision. + + limit_clause, offset_clause = ( + select._limit_clause, + select._offset_clause, + ) + + if limit_clause is None and offset_clause is None: + return "" + elif offset_clause is not None: + # As suggested by the MySQL docs, need to apply an + # artificial limit if one wasn't provided + # https://dev.mysql.com/doc/refman/5.0/en/select.html + if limit_clause is None: + # hardwire the upper limit. Currently + # needed by OurSQL with Python 3 + # (https://bugs.launchpad.net/oursql/+bug/686232), + # but also is consistent with the usage of the upper + # bound as part of MySQL's "syntax" for OFFSET with + # no LIMIT + return " \n LIMIT %s, %s" % ( + self.process(offset_clause, **kw), + "18446744073709551615", + ) + else: + return " \n LIMIT %s, %s" % ( + self.process(offset_clause, **kw), + self.process(limit_clause, **kw), + ) + else: + # No offset provided, so just use the limit + return " \n LIMIT %s" % (self.process(limit_clause, **kw),) + + def update_limit_clause(self, update_stmt): + limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None) + if limit: + return "LIMIT %s" % limit + else: + return None + + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + kw["asfrom"] = True + return ", ".join( + t._compiler_dispatch(self, **kw) + for t in [from_table] + list(extra_froms) + ) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return None + + def delete_table_clause(self, delete_stmt, from_table, extra_froms): + """If we have extra froms make sure we render any alias as hint.""" + ashint = False + if extra_froms: + ashint = True + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, ashint=ashint + ) + + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the DELETE .. USING clause specific to MySQL.""" + kw["asfrom"] = True + return "USING " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) + + def visit_empty_set_expr(self, element_types): + return ( + "SELECT %(outer)s FROM (SELECT %(inner)s) " + "as _empty_set WHERE 1!=1" + % { + "inner": ", ".join( + "1 AS _in_%s" % idx + for idx, type_ in enumerate(element_types) + ), + "outer": ", ".join( + "_in_%s" % idx for idx, type_ in enumerate(element_types) + ), + } + ) + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "NOT (%s <=> %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "%s <=> %s" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def _mariadb_regexp_flags(self, flags, pattern, **kw): + return "CONCAT('(?', %s, ')', %s)" % ( + self.process(flags, **kw), + self.process(pattern, **kw), + ) + + def _regexp_match(self, op_string, binary, operator, **kw): + flags = binary.modifiers["flags"] + if flags is None: + return self._generate_generic_binary(binary, op_string, **kw) + elif self.dialect.is_mariadb: + return "%s%s%s" % ( + self.process(binary.left, **kw), + op_string, + self._mariadb_regexp_flags(flags, binary.right), + ) + else: + text = "REGEXP_LIKE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + self.process(flags, **kw), + ) + if op_string == " NOT REGEXP ": + return "NOT %s" % text + else: + return text + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match(" REGEXP ", binary, operator, **kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match(" NOT REGEXP ", binary, operator, **kw) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + flags = binary.modifiers["flags"] + replacement = binary.modifiers["replacement"] + if flags is None: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + self.process(replacement, **kw), + ) + elif self.dialect.is_mariadb: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self._mariadb_regexp_flags(flags, binary.right), + self.process(replacement, **kw), + ) + else: + return "REGEXP_REPLACE(%s, %s, %s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + self.process(replacement, **kw), + self.process(flags, **kw), + ) + + +class MySQLDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kw): + """Builds column DDL.""" + + colspec = [ + self.preparer.format_column(column), + self.dialect.type_compiler.process( + column.type, type_expression=column + ), + ] + + if column.computed is not None: + colspec.append(self.process(column.computed)) + + is_timestamp = isinstance( + column.type._unwrapped_dialect_impl(self.dialect), + sqltypes.TIMESTAMP, + ) + + if not column.nullable: + colspec.append("NOT NULL") + + # see: https://docs.sqlalchemy.org/en/latest/dialects/mysql.html#mysql_timestamp_null # noqa + elif column.nullable and is_timestamp: + colspec.append("NULL") + + comment = column.comment + if comment is not None: + literal = self.sql_compiler.render_literal_value( + comment, sqltypes.String() + ) + colspec.append("COMMENT " + literal) + + if ( + column.table is not None + and column is column.table._autoincrement_column + and ( + column.server_default is None + or isinstance(column.server_default, sa_schema.Identity) + ) + and not ( + self.dialect.supports_sequences + and isinstance(column.default, sa_schema.Sequence) + and not column.default.optional + ) + ): + colspec.append("AUTO_INCREMENT") + else: + default = self.get_column_default_string(column) + if default is not None: + colspec.append("DEFAULT " + default) + return " ".join(colspec) + + def post_create_table(self, table): + """Build table-level CREATE options like ENGINE and COLLATE.""" + + table_opts = [] + + opts = dict( + (k[len(self.dialect.name) + 1 :].upper(), v) + for k, v in table.kwargs.items() + if k.startswith("%s_" % self.dialect.name) + ) + + if table.comment is not None: + opts["COMMENT"] = table.comment + + partition_options = [ + "PARTITION_BY", + "PARTITIONS", + "SUBPARTITIONS", + "SUBPARTITION_BY", + ] + + nonpart_options = set(opts).difference(partition_options) + part_options = set(opts).intersection(partition_options) + + for opt in topological.sort( + [ + ("DEFAULT_CHARSET", "COLLATE"), + ("DEFAULT_CHARACTER_SET", "COLLATE"), + ("CHARSET", "COLLATE"), + ("CHARACTER_SET", "COLLATE"), + ], + nonpart_options, + ): + arg = opts[opt] + if opt in _reflection._options_of_type_string: + + arg = self.sql_compiler.render_literal_value( + arg, sqltypes.String() + ) + + if opt in ( + "DATA_DIRECTORY", + "INDEX_DIRECTORY", + "DEFAULT_CHARACTER_SET", + "CHARACTER_SET", + "DEFAULT_CHARSET", + "DEFAULT_COLLATE", + ): + opt = opt.replace("_", " ") + + joiner = "=" + if opt in ( + "TABLESPACE", + "DEFAULT CHARACTER SET", + "CHARACTER SET", + "COLLATE", + ): + joiner = " " + + table_opts.append(joiner.join((opt, arg))) + + for opt in topological.sort( + [ + ("PARTITION_BY", "PARTITIONS"), + ("PARTITION_BY", "SUBPARTITION_BY"), + ("PARTITION_BY", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITION_BY"), + ("SUBPARTITION_BY", "SUBPARTITIONS"), + ], + part_options, + ): + arg = opts[opt] + if opt in _reflection._options_of_type_string: + arg = self.sql_compiler.render_literal_value( + arg, sqltypes.String() + ) + + opt = opt.replace("_", " ") + joiner = " " + + table_opts.append(joiner.join((opt, arg))) + + return " ".join(table_opts) + + def visit_create_index(self, create, **kw): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + table = preparer.format_table(index.table) + + columns = [ + self.sql_compiler.process( + elements.Grouping(expr) + if ( + isinstance(expr, elements.BinaryExpression) + or ( + isinstance(expr, elements.UnaryExpression) + and expr.modifier + not in (operators.desc_op, operators.asc_op) + ) + or isinstance(expr, functions.FunctionElement) + ) + else expr, + include_table=False, + literal_binds=True, + ) + for expr in index.expressions + ] + + name = self._prepared_index_name(index) + + text = "CREATE " + if index.unique: + text += "UNIQUE " + + index_prefix = index.kwargs.get("%s_prefix" % self.dialect.name, None) + if index_prefix: + text += index_prefix + " " + + text += "INDEX " + if create.if_not_exists: + text += "IF NOT EXISTS " + text += "%s ON %s " % (name, table) + + length = index.dialect_options[self.dialect.name]["length"] + if length is not None: + + if isinstance(length, dict): + # length value can be a (column_name --> integer value) + # mapping specifying the prefix length for each column of the + # index + columns = ", ".join( + "%s(%d)" % (expr, length[col.name]) + if col.name in length + else ( + "%s(%d)" % (expr, length[expr]) + if expr in length + else "%s" % expr + ) + for col, expr in zip(index.expressions, columns) + ) + else: + # or can be an integer value specifying the same + # prefix length for all columns of the index + columns = ", ".join( + "%s(%d)" % (col, length) for col in columns + ) + else: + columns = ", ".join(columns) + text += "(%s)" % columns + + parser = index.dialect_options["mysql"]["with_parser"] + if parser is not None: + text += " WITH PARSER %s" % (parser,) + + using = index.dialect_options["mysql"]["using"] + if using is not None: + text += " USING %s" % (preparer.quote(using)) + + return text + + def visit_primary_key_constraint(self, constraint): + text = super(MySQLDDLCompiler, self).visit_primary_key_constraint( + constraint + ) + using = constraint.dialect_options["mysql"]["using"] + if using: + text += " USING %s" % (self.preparer.quote(using)) + return text + + def visit_drop_index(self, drop): + index = drop.element + text = "\nDROP INDEX " + if drop.if_exists: + text += "IF EXISTS " + + return text + "%s ON %s" % ( + self._prepared_index_name(index, include_schema=False), + self.preparer.format_table(index.table), + ) + + def visit_drop_constraint(self, drop): + constraint = drop.element + if isinstance(constraint, sa_schema.ForeignKeyConstraint): + qual = "FOREIGN KEY " + const = self.preparer.format_constraint(constraint) + elif isinstance(constraint, sa_schema.PrimaryKeyConstraint): + qual = "PRIMARY KEY " + const = "" + elif isinstance(constraint, sa_schema.UniqueConstraint): + qual = "INDEX " + const = self.preparer.format_constraint(constraint) + elif isinstance(constraint, sa_schema.CheckConstraint): + if self.dialect.is_mariadb: + qual = "CONSTRAINT " + else: + qual = "CHECK " + const = self.preparer.format_constraint(constraint) + else: + qual = "" + const = self.preparer.format_constraint(constraint) + return "ALTER TABLE %s DROP %s%s" % ( + self.preparer.format_table(constraint.table), + qual, + const, + ) + + def define_constraint_match(self, constraint): + if constraint.match is not None: + raise exc.CompileError( + "MySQL ignores the 'MATCH' keyword while at the same time " + "causes ON UPDATE/ON DELETE clauses to be ignored." + ) + return "" + + def visit_set_table_comment(self, create): + return "ALTER TABLE %s COMMENT %s" % ( + self.preparer.format_table(create.element), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_table_comment(self, create): + return "ALTER TABLE %s COMMENT ''" % ( + self.preparer.format_table(create.element) + ) + + def visit_set_column_comment(self, create): + return "ALTER TABLE %s CHANGE %s %s" % ( + self.preparer.format_table(create.element.table), + self.preparer.format_column(create.element), + self.get_column_specification(create.element), + ) + + +class MySQLTypeCompiler(compiler.GenericTypeCompiler): + def _extend_numeric(self, type_, spec): + "Extend a numeric-type declaration with MySQL specific extensions." + + if not self._mysql_type(type_): + return spec + + if type_.unsigned: + spec += " UNSIGNED" + if type_.zerofill: + spec += " ZEROFILL" + return spec + + def _extend_string(self, type_, defaults, spec): + """Extend a string-type declaration with standard SQL CHARACTER SET / + COLLATE annotations and MySQL specific extensions. + + """ + + def attr(name): + return getattr(type_, name, defaults.get(name)) + + if attr("charset"): + charset = "CHARACTER SET %s" % attr("charset") + elif attr("ascii"): + charset = "ASCII" + elif attr("unicode"): + charset = "UNICODE" + else: + charset = None + + if attr("collation"): + collation = "COLLATE %s" % type_.collation + elif attr("binary"): + collation = "BINARY" + else: + collation = None + + if attr("national"): + # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. + return " ".join( + [c for c in ("NATIONAL", spec, collation) if c is not None] + ) + return " ".join( + [c for c in (spec, charset, collation) if c is not None] + ) + + def _mysql_type(self, type_): + return isinstance(type_, (_StringType, _NumericType)) + + def visit_NUMERIC(self, type_, **kw): + if type_.precision is None: + return self._extend_numeric(type_, "NUMERIC") + elif type_.scale is None: + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s)" % {"precision": type_.precision}, + ) + else: + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + + def visit_DECIMAL(self, type_, **kw): + if type_.precision is None: + return self._extend_numeric(type_, "DECIMAL") + elif type_.scale is None: + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s)" % {"precision": type_.precision}, + ) + else: + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + + def visit_DOUBLE(self, type_, **kw): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric( + type_, + "DOUBLE(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + else: + return self._extend_numeric(type_, "DOUBLE") + + def visit_REAL(self, type_, **kw): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric( + type_, + "REAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + else: + return self._extend_numeric(type_, "REAL") + + def visit_FLOAT(self, type_, **kw): + if ( + self._mysql_type(type_) + and type_.scale is not None + and type_.precision is not None + ): + return self._extend_numeric( + type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale) + ) + elif type_.precision is not None: + return self._extend_numeric( + type_, "FLOAT(%s)" % (type_.precision,) + ) + else: + return self._extend_numeric(type_, "FLOAT") + + def visit_INTEGER(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "INTEGER(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "INTEGER") + + def visit_BIGINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "BIGINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "BIGINT") + + def visit_MEDIUMINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "MEDIUMINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "MEDIUMINT") + + def visit_TINYINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, "TINYINT(%s)" % type_.display_width + ) + else: + return self._extend_numeric(type_, "TINYINT") + + def visit_SMALLINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "SMALLINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "SMALLINT") + + def visit_BIT(self, type_, **kw): + if type_.length is not None: + return "BIT(%s)" % type_.length + else: + return "BIT" + + def visit_DATETIME(self, type_, **kw): + if getattr(type_, "fsp", None): + return "DATETIME(%d)" % type_.fsp + else: + return "DATETIME" + + def visit_DATE(self, type_, **kw): + return "DATE" + + def visit_TIME(self, type_, **kw): + if getattr(type_, "fsp", None): + return "TIME(%d)" % type_.fsp + else: + return "TIME" + + def visit_TIMESTAMP(self, type_, **kw): + if getattr(type_, "fsp", None): + return "TIMESTAMP(%d)" % type_.fsp + else: + return "TIMESTAMP" + + def visit_YEAR(self, type_, **kw): + if type_.display_width is None: + return "YEAR" + else: + return "YEAR(%s)" % type_.display_width + + def visit_TEXT(self, type_, **kw): + if type_.length: + return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) + else: + return self._extend_string(type_, {}, "TEXT") + + def visit_TINYTEXT(self, type_, **kw): + return self._extend_string(type_, {}, "TINYTEXT") + + def visit_MEDIUMTEXT(self, type_, **kw): + return self._extend_string(type_, {}, "MEDIUMTEXT") + + def visit_LONGTEXT(self, type_, **kw): + return self._extend_string(type_, {}, "LONGTEXT") + + def visit_VARCHAR(self, type_, **kw): + if type_.length: + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) + else: + raise exc.CompileError( + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) + + def visit_CHAR(self, type_, **kw): + if type_.length: + return self._extend_string( + type_, {}, "CHAR(%(length)s)" % {"length": type_.length} + ) + else: + return self._extend_string(type_, {}, "CHAR") + + def visit_NVARCHAR(self, type_, **kw): + # We'll actually generate the equiv. "NATIONAL VARCHAR" instead + # of "NVARCHAR". + if type_.length: + return self._extend_string( + type_, + {"national": True}, + "VARCHAR(%(length)s)" % {"length": type_.length}, + ) + else: + raise exc.CompileError( + "NVARCHAR requires a length on dialect %s" % self.dialect.name + ) + + def visit_NCHAR(self, type_, **kw): + # We'll actually generate the equiv. + # "NATIONAL CHAR" instead of "NCHAR". + if type_.length: + return self._extend_string( + type_, + {"national": True}, + "CHAR(%(length)s)" % {"length": type_.length}, + ) + else: + return self._extend_string(type_, {"national": True}, "CHAR") + + def visit_VARBINARY(self, type_, **kw): + return "VARBINARY(%d)" % type_.length + + def visit_JSON(self, type_, **kw): + return "JSON" + + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_) + + def visit_enum(self, type_, **kw): + if not type_.native_enum: + return super(MySQLTypeCompiler, self).visit_enum(type_) + else: + return self._visit_enumerated_values("ENUM", type_, type_.enums) + + def visit_BLOB(self, type_, **kw): + if type_.length: + return "BLOB(%d)" % type_.length + else: + return "BLOB" + + def visit_TINYBLOB(self, type_, **kw): + return "TINYBLOB" + + def visit_MEDIUMBLOB(self, type_, **kw): + return "MEDIUMBLOB" + + def visit_LONGBLOB(self, type_, **kw): + return "LONGBLOB" + + def _visit_enumerated_values(self, name, type_, enumerated_values): + quoted_enums = [] + for e in enumerated_values: + quoted_enums.append("'%s'" % e.replace("'", "''")) + return self._extend_string( + type_, {}, "%s(%s)" % (name, ",".join(quoted_enums)) + ) + + def visit_ENUM(self, type_, **kw): + return self._visit_enumerated_values("ENUM", type_, type_.enums) + + def visit_SET(self, type_, **kw): + return self._visit_enumerated_values("SET", type_, type_.values) + + def visit_BOOLEAN(self, type_, **kw): + return "BOOL" + + +class MySQLIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS_MYSQL + + def __init__(self, dialect, server_ansiquotes=False, **kw): + if not server_ansiquotes: + quote = "`" + else: + quote = '"' + + super(MySQLIdentifierPreparer, self).__init__( + dialect, initial_quote=quote, escape_quote=quote + ) + + def _quote_free_identifiers(self, *ids): + """Unilaterally identifier-quote any number of strings.""" + + return tuple([self.quote_identifier(i) for i in ids if i is not None]) + + +class MariaDBIdentifierPreparer(MySQLIdentifierPreparer): + reserved_words = RESERVED_WORDS_MARIADB + + +@log.class_logger +class MySQLDialect(default.DefaultDialect): + """Details of the MySQL dialect. + Not used directly in application code. + """ + + name = "mysql" + supports_statement_cache = True + + supports_alter = True + + # MySQL has no true "boolean" type; we + # allow for the "true" and "false" keywords, however + supports_native_boolean = False + + # identifiers are 64, however aliases can be 255... + max_identifier_length = 255 + max_index_name_length = 64 + max_constraint_name_length = 64 + + supports_native_enum = True + + supports_sequences = False # default for MySQL ... + # ... may be updated to True for MariaDB 10.3+ in initialize() + + sequences_optional = False + + supports_for_update_of = False # default for MySQL ... + # ... may be updated to True for MySQL 8+ in initialize() + + # MySQL doesn't support "DEFAULT VALUES" but *does* support + # "VALUES (DEFAULT)" + supports_default_values = False + supports_default_metavalue = True + + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + supports_multivalues_insert = True + + supports_comments = True + inline_comments = True + default_paramstyle = "format" + colspecs = colspecs + + cte_follows_insert = True + + statement_compiler = MySQLCompiler + ddl_compiler = MySQLDDLCompiler + type_compiler = MySQLTypeCompiler + ischema_names = ischema_names + preparer = MySQLIdentifierPreparer + + is_mariadb = False + _mariadb_normalized_version_info = None + + # default SQL compilation settings - + # these are modified upon initialize(), + # i.e. first connect + _backslash_escapes = True + _server_ansiquotes = False + + construct_arguments = [ + (sa_schema.Table, {"*": None}), + (sql.Update, {"limit": None}), + (sa_schema.PrimaryKeyConstraint, {"using": None}), + ( + sa_schema.Index, + { + "using": None, + "length": None, + "prefix": None, + "with_parser": None, + }, + ), + ] + + def __init__( + self, + isolation_level=None, + json_serializer=None, + json_deserializer=None, + is_mariadb=None, + **kwargs + ): + kwargs.pop("use_ansiquotes", None) # legacy + default.DefaultDialect.__init__(self, **kwargs) + self.isolation_level = isolation_level + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer + self._set_mariadb(is_mariadb, None) + + def on_connect(self): + if self.isolation_level is not None: + + def connect(conn): + self.set_isolation_level(conn, self.isolation_level) + + return connect + else: + return None + + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ] + ) + + def set_isolation_level(self, connection, level): + level = level.replace("_", " ") + + # adjust for ConnectionFairy being present + # allows attribute set e.g. "connection.autocommit = True" + # to work properly + if hasattr(connection, "dbapi_connection"): + connection = connection.dbapi_connection + + self._set_isolation_level(connection, level) + + def _set_isolation_level(self, connection, level): + if level not in self._isolation_lookup: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) + ) + cursor = connection.cursor() + cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level) + cursor.execute("COMMIT") + cursor.close() + + def get_isolation_level(self, connection): + cursor = connection.cursor() + if self._is_mysql and self.server_version_info >= (5, 7, 20): + cursor.execute("SELECT @@transaction_isolation") + else: + cursor.execute("SELECT @@tx_isolation") + row = cursor.fetchone() + if row is None: + util.warn( + "Could not retrieve transaction isolation level for MySQL " + "connection." + ) + raise NotImplementedError() + val = row[0] + cursor.close() + if util.py3k and isinstance(val, bytes): + val = val.decode() + return val.upper().replace("-", " ") + + @classmethod + def _is_mariadb_from_url(cls, url): + dbapi = cls.dbapi() + dialect = cls(dbapi=dbapi) + + cargs, cparams = dialect.create_connect_args(url) + conn = dialect.connect(*cargs, **cparams) + try: + cursor = conn.cursor() + cursor.execute("SELECT VERSION() LIKE '%MariaDB%'") + val = cursor.fetchone()[0] + except: + raise + else: + return bool(val) + finally: + conn.close() + + def _get_server_version_info(self, connection): + # get database server version info explicitly over the wire + # to avoid proxy servers like MaxScale getting in the + # way with their own values, see #4205 + dbapi_con = connection.connection + cursor = dbapi_con.cursor() + cursor.execute("SELECT VERSION()") + val = cursor.fetchone()[0] + cursor.close() + if util.py3k and isinstance(val, bytes): + val = val.decode() + + return self._parse_server_version(val) + + def _parse_server_version(self, val): + version = [] + is_mariadb = False + + r = re.compile(r"[.\-+]") + tokens = r.split(val) + for token in tokens: + parsed_token = re.match( + r"^(?:(\d+)(?:a|b|c)?|(MariaDB\w*))$", token + ) + if not parsed_token: + continue + elif parsed_token.group(2): + self._mariadb_normalized_version_info = tuple(version[-3:]) + is_mariadb = True + else: + digit = int(parsed_token.group(1)) + version.append(digit) + + server_version_info = tuple(version) + + self._set_mariadb(server_version_info and is_mariadb, val) + + if not is_mariadb: + self._mariadb_normalized_version_info = server_version_info + + if server_version_info < (5, 0, 2): + raise NotImplementedError( + "the MySQL/MariaDB dialect supports server " + "version info 5.0.2 and above." + ) + + # setting it here to help w the test suite + self.server_version_info = server_version_info + return server_version_info + + def _set_mariadb(self, is_mariadb, server_version_info): + if is_mariadb is None: + return + + if not is_mariadb and self.is_mariadb: + raise exc.InvalidRequestError( + "MySQL version %s is not a MariaDB variant." + % (server_version_info,) + ) + if is_mariadb: + self.preparer = MariaDBIdentifierPreparer + # this would have been set by the default dialect already, + # so set it again + self.identifier_preparer = self.preparer(self) + self.is_mariadb = is_mariadb + + def do_begin_twophase(self, connection, xid): + connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid)) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("XA END :xid"), dict(xid=xid)) + connection.execute(sql.text("XA PREPARE :xid"), dict(xid=xid)) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + connection.execute(sql.text("XA END :xid"), dict(xid=xid)) + connection.execute(sql.text("XA ROLLBACK :xid"), dict(xid=xid)) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute(sql.text("XA COMMIT :xid"), dict(xid=xid)) + + def do_recover_twophase(self, connection): + resultset = connection.exec_driver_sql("XA RECOVER") + return [row["data"][0 : row["gtrid_length"]] for row in resultset] + + def is_disconnect(self, e, connection, cursor): + if isinstance( + e, + ( + self.dbapi.OperationalError, + self.dbapi.ProgrammingError, + self.dbapi.InterfaceError, + ), + ) and self._extract_error_code(e) in ( + 1927, + 2006, + 2013, + 2014, + 2045, + 2055, + 4031, + ): + return True + elif isinstance( + e, (self.dbapi.InterfaceError, self.dbapi.InternalError) + ): + # if underlying connection is closed, + # this is the error you get + return "(0, '')" in str(e) + else: + return False + + def _compat_fetchall(self, rp, charset=None): + """Proxy result rows to smooth over MySQL-Python driver + inconsistencies.""" + + return [_DecodingRow(row, charset) for row in rp.fetchall()] + + def _compat_fetchone(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver + inconsistencies.""" + + row = rp.fetchone() + if row: + return _DecodingRow(row, charset) + else: + return None + + def _compat_first(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver + inconsistencies.""" + + row = rp.first() + if row: + return _DecodingRow(row, charset) + else: + return None + + def _extract_error_code(self, exception): + raise NotImplementedError() + + def _get_default_schema_name(self, connection): + return connection.exec_driver_sql("SELECT DATABASE()").scalar() + + def has_table(self, connection, table_name, schema=None): + self._ensure_has_table_connection(connection) + + if schema is None: + schema = self.default_schema_name + + rs = connection.execute( + text( + "SELECT COUNT(*) FROM information_schema.tables WHERE " + "table_schema = :table_schema AND " + "table_name = :table_name" + ).bindparams( + sql.bindparam("table_schema", type_=Unicode), + sql.bindparam("table_name", type_=Unicode), + ), + { + "table_schema": util.text_type(schema), + "table_name": util.text_type(table_name), + }, + ) + return bool(rs.scalar()) + + def has_sequence(self, connection, sequence_name, schema=None): + if not self.supports_sequences: + self._sequences_not_supported() + if not schema: + schema = self.default_schema_name + # MariaDB implements sequences as a special type of table + # + cursor = connection.execute( + sql.text( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_TYPE='SEQUENCE' and TABLE_NAME=:name AND " + "TABLE_SCHEMA=:schema_name" + ), + dict( + name=util.text_type(sequence_name), + schema_name=util.text_type(schema), + ), + ) + return cursor.first() is not None + + def _sequences_not_supported(self): + raise NotImplementedError( + "Sequences are supported only by the " + "MariaDB series 10.3 or greater" + ) + + @reflection.cache + def get_sequence_names(self, connection, schema=None, **kw): + if not self.supports_sequences: + self._sequences_not_supported() + if not schema: + schema = self.default_schema_name + # MariaDB implements sequences as a special type of table + cursor = connection.execute( + sql.text( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_TYPE='SEQUENCE' and TABLE_SCHEMA=:schema_name" + ), + dict(schema_name=schema), + ) + return [ + row[0] + for row in self._compat_fetchall( + cursor, charset=self._connection_charset + ) + ] + + def initialize(self, connection): + # this is driver-based, does not need server version info + # and is fairly critical for even basic SQL operations + self._connection_charset = self._detect_charset(connection) + + # call super().initialize() because we need to have + # server_version_info set up. in 1.4 under python 2 only this does the + # "check unicode returns" thing, which is the one area that some + # SQL gets compiled within initialize() currently + default.DefaultDialect.initialize(self, connection) + + self._detect_sql_mode(connection) + self._detect_ansiquotes(connection) # depends on sql mode + self._detect_casing(connection) + if self._server_ansiquotes: + # if ansiquotes == True, build a new IdentifierPreparer + # with the new setting + self.identifier_preparer = self.preparer( + self, server_ansiquotes=self._server_ansiquotes + ) + + self.supports_sequences = ( + self.is_mariadb and self.server_version_info >= (10, 3) + ) + + self.supports_for_update_of = ( + self._is_mysql and self.server_version_info >= (8,) + ) + + self._needs_correct_for_88718_96365 = ( + not self.is_mariadb and self.server_version_info >= (8,) + ) + + self._warn_for_known_db_issues() + + def _warn_for_known_db_issues(self): + if self.is_mariadb: + mdb_version = self._mariadb_normalized_version_info + if mdb_version > (10, 2) and mdb_version < (10, 2, 9): + util.warn( + "MariaDB %r before 10.2.9 has known issues regarding " + "CHECK constraints, which impact handling of NULL values " + "with SQLAlchemy's boolean datatype (MDEV-13596). An " + "additional issue prevents proper migrations of columns " + "with CHECK constraints (MDEV-11114). Please upgrade to " + "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 " + "series, to avoid these issues." % (mdb_version,) + ) + + @property + def _support_float_cast(self): + if not self.server_version_info: + return False + elif self.is_mariadb: + # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/ + return self.server_version_info >= (10, 4, 5) + else: + # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa + return self.server_version_info >= (8, 0, 17) + + @property + def _is_mariadb(self): + return self.is_mariadb + + @property + def _is_mysql(self): + return not self.is_mariadb + + @property + def _is_mariadb_102(self): + return self.is_mariadb and self._mariadb_normalized_version_info > ( + 10, + 2, + ) + + @reflection.cache + def get_schema_names(self, connection, **kw): + rp = connection.exec_driver_sql("SHOW schemas") + return [r[0] for r in rp] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + """Return a Unicode SHOW TABLES from a given schema.""" + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + + charset = self._connection_charset + + rp = connection.exec_driver_sql( + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(current_schema) + ) + + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] == "BASE TABLE" + ] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + charset = self._connection_charset + rp = connection.exec_driver_sql( + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(schema) + ) + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] in ("VIEW", "SYSTEM VIEW") + ] + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + return parsed_state.table_options + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + return parsed_state.columns + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + for key in parsed_state.keys: + if key["type"] == "PRIMARY": + # There can be only one. + cols = [s[0] for s in key["columns"]] + return {"constrained_columns": cols, "name": None} + return {"constrained_columns": [], "name": None} + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + default_schema = None + + fkeys = [] + + for spec in parsed_state.fk_constraints: + ref_name = spec["table"][-1] + ref_schema = len(spec["table"]) > 1 and spec["table"][-2] or schema + + if not ref_schema: + if default_schema is None: + default_schema = connection.dialect.default_schema_name + if schema == default_schema: + ref_schema = schema + + loc_names = spec["local"] + ref_names = spec["foreign"] + + con_kw = {} + for opt in ("onupdate", "ondelete"): + if spec.get(opt, False) not in ("NO ACTION", None): + con_kw[opt] = spec[opt] + + fkey_d = { + "name": spec["name"], + "constrained_columns": loc_names, + "referred_schema": ref_schema, + "referred_table": ref_name, + "referred_columns": ref_names, + "options": con_kw, + } + fkeys.append(fkey_d) + + if self._needs_correct_for_88718_96365: + self._correct_for_mysql_bugs_88718_96365(fkeys, connection) + + return fkeys + + def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection): + # Foreign key is always in lower case (MySQL 8.0) + # https://bugs.mysql.com/bug.php?id=88718 + # issue #4344 for SQLAlchemy + + # table name also for MySQL 8.0 + # https://bugs.mysql.com/bug.php?id=96365 + # issue #4751 for SQLAlchemy + + # for lower_case_table_names=2, information_schema.columns + # preserves the original table/schema casing, but SHOW CREATE + # TABLE does not. this problem is not in lower_case_table_names=1, + # but use case-insensitive matching for these two modes in any case. + + if self._casing in (1, 2): + + def lower(s): + return s.lower() + + else: + # if on case sensitive, there can be two tables referenced + # with the same name different casing, so we need to use + # case-sensitive matching. + def lower(s): + return s + + default_schema_name = connection.dialect.default_schema_name + col_tuples = [ + ( + lower(rec["referred_schema"] or default_schema_name), + lower(rec["referred_table"]), + col_name, + ) + for rec in fkeys + for col_name in rec["referred_columns"] + ] + + if col_tuples: + + correct_for_wrong_fk_case = connection.execute( + sql.text( + """ + select table_schema, table_name, column_name + from information_schema.columns + where (table_schema, table_name, lower(column_name)) in + :table_data; + """ + ).bindparams(sql.bindparam("table_data", expanding=True)), + dict(table_data=col_tuples), + ) + + # in casing=0, table name and schema name come back in their + # exact case. + # in casing=1, table name and schema name come back in lower + # case. + # in casing=2, table name and schema name come back from the + # information_schema.columns view in the case + # that was used in CREATE DATABASE and CREATE TABLE, but + # SHOW CREATE TABLE converts them to *lower case*, therefore + # not matching. So for this case, case-insensitive lookup + # is necessary + d = defaultdict(dict) + for schema, tname, cname in correct_for_wrong_fk_case: + d[(lower(schema), lower(tname))]["SCHEMANAME"] = schema + d[(lower(schema), lower(tname))]["TABLENAME"] = tname + d[(lower(schema), lower(tname))][cname.lower()] = cname + + for fkey in fkeys: + rec = d[ + ( + lower(fkey["referred_schema"] or default_schema_name), + lower(fkey["referred_table"]), + ) + ] + + fkey["referred_table"] = rec["TABLENAME"] + if fkey["referred_schema"] is not None: + fkey["referred_schema"] = rec["SCHEMANAME"] + + fkey["referred_columns"] = [ + rec[col.lower()] for col in fkey["referred_columns"] + ] + + @reflection.cache + def get_check_constraints(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + + return [ + {"name": spec["name"], "sqltext": spec["sqltext"]} + for spec in parsed_state.ck_constraints + ] + + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + return { + "text": parsed_state.table_options.get( + "%s_comment" % self.name, None + ) + } + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + + indexes = [] + + for spec in parsed_state.keys: + dialect_options = {} + unique = False + flavor = spec["type"] + if flavor == "PRIMARY": + continue + if flavor == "UNIQUE": + unique = True + elif flavor in ("FULLTEXT", "SPATIAL"): + dialect_options["%s_prefix" % self.name] = flavor + elif flavor is None: + pass + else: + self.logger.info( + "Converting unknown KEY type %s to a plain KEY", flavor + ) + pass + + if spec["parser"]: + dialect_options["%s_with_parser" % (self.name)] = spec[ + "parser" + ] + + index_d = {} + if dialect_options: + index_d["dialect_options"] = dialect_options + + index_d["name"] = spec["name"] + index_d["column_names"] = [s[0] for s in spec["columns"]] + index_d["unique"] = unique + if flavor: + index_d["type"] = flavor + indexes.append(index_d) + return indexes + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + + return [ + { + "name": key["name"], + "column_names": [col[0] for col in key["columns"]], + "duplicates_index": key["name"], + } + for key in parsed_state.keys + if key["type"] == "UNIQUE" + ] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + + charset = self._connection_charset + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers(schema, view_name) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) + return sql + + def _parsed_state_or_create( + self, connection, table_name, schema=None, **kw + ): + return self._setup_parser( + connection, + table_name, + schema, + info_cache=kw.get("info_cache", None), + ) + + @util.memoized_property + def _tabledef_parser(self): + """return the MySQLTableDefinitionParser, generate if needed. + + The deferred creation ensures that the dialect has + retrieved server version information first. + + """ + preparer = self.identifier_preparer + return _reflection.MySQLTableDefinitionParser(self, preparer) + + @reflection.cache + def _setup_parser(self, connection, table_name, schema=None, **kw): + charset = self._connection_charset + parser = self._tabledef_parser + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) + if re.match(r"^CREATE (?:ALGORITHM)?.* VIEW", sql): + # Adapt views to something table-like. + columns = self._describe_table( + connection, None, charset, full_name=full_name + ) + sql = parser._describe_to_create(table_name, columns) + return parser.parse(sql, charset) + + def _fetch_setting(self, connection, setting_name): + charset = self._connection_charset + + if self.server_version_info and self.server_version_info < (5, 6): + sql = "SHOW VARIABLES LIKE '%s'" % setting_name + fetch_col = 1 + else: + sql = "SELECT @@%s" % setting_name + fetch_col = 0 + + show_var = connection.exec_driver_sql(sql) + row = self._compat_first(show_var, charset=charset) + if not row: + return None + else: + return row[fetch_col] + + def _detect_charset(self, connection): + raise NotImplementedError() + + def _detect_casing(self, connection): + """Sniff out identifier case sensitivity. + + Cached per-connection. This value can not change without a server + restart. + + """ + # https://dev.mysql.com/doc/refman/en/identifier-case-sensitivity.html + + setting = self._fetch_setting(connection, "lower_case_table_names") + if setting is None: + cs = 0 + else: + # 4.0.15 returns OFF or ON according to [ticket:489] + # 3.23 doesn't, 4.0.27 doesn't.. + if setting == "OFF": + cs = 0 + elif setting == "ON": + cs = 1 + else: + cs = int(setting) + self._casing = cs + return cs + + def _detect_collations(self, connection): + """Pull the active COLLATIONS list from the server. + + Cached per-connection. + """ + + collations = {} + charset = self._connection_charset + rs = connection.exec_driver_sql("SHOW COLLATION") + for row in self._compat_fetchall(rs, charset): + collations[row[0]] = row[1] + return collations + + def _detect_sql_mode(self, connection): + setting = self._fetch_setting(connection, "sql_mode") + + if setting is None: + util.warn( + "Could not retrieve SQL_MODE; please ensure the " + "MySQL user has permissions to SHOW VARIABLES" + ) + self._sql_mode = "" + else: + self._sql_mode = setting or "" + + def _detect_ansiquotes(self, connection): + """Detect and adjust for the ANSI_QUOTES sql mode.""" + + mode = self._sql_mode + if not mode: + mode = "" + elif mode.isdigit(): + mode_no = int(mode) + mode = (mode_no | 4 == mode_no) and "ANSI_QUOTES" or "" + + self._server_ansiquotes = "ANSI_QUOTES" in mode + + # as of MySQL 5.0.1 + self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode + + def _show_create_table( + self, connection, table, charset=None, full_name=None + ): + """Run SHOW CREATE TABLE for a ``Table``.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "SHOW CREATE TABLE %s" % full_name + + rp = None + try: + rp = connection.execution_options( + skip_user_error_events=True + ).exec_driver_sql(st) + except exc.DBAPIError as e: + if self._extract_error_code(e.orig) == 1146: + util.raise_(exc.NoSuchTableError(full_name), replace_context=e) + else: + raise + row = self._compat_first(rp, charset=charset) + if not row: + raise exc.NoSuchTableError(full_name) + return row[1].strip() + + def _describe_table(self, connection, table, charset=None, full_name=None): + """Run DESCRIBE for a ``Table`` and return processed rows.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "DESCRIBE %s" % full_name + + rp, rows = None, None + try: + try: + rp = connection.execution_options( + skip_user_error_events=True + ).exec_driver_sql(st) + except exc.DBAPIError as e: + code = self._extract_error_code(e.orig) + if code == 1146: + util.raise_( + exc.NoSuchTableError(full_name), replace_context=e + ) + elif code == 1356: + util.raise_( + exc.UnreflectableTableError( + "Table or view named %s could not be " + "reflected: %s" % (full_name, e) + ), + replace_context=e, + ) + else: + raise + rows = self._compat_fetchall(rp, charset=charset) + finally: + if rp: + rp.close() + return rows + + +class _DecodingRow(object): + """Return unicode-decoded values based on type inspection. + + Smooth over data type issues (esp. with alpha driver versions) and + normalize strings as Unicode regardless of user-configured driver + encoding settings. + + """ + + # Some MySQL-python versions can return some columns as + # sets.Set(['value']) (seriously) but thankfully that doesn't + # seem to come up in DDL queries. + + _encoding_compat = { + "koi8r": "koi8_r", + "koi8u": "koi8_u", + "utf16": "utf-16-be", # MySQL's uft16 is always bigendian + "utf8mb4": "utf8", # real utf8 + "utf8mb3": "utf8", # real utf8; saw this happen on CI but I cannot + # reproduce, possibly mariadb10.6 related + "eucjpms": "ujis", + } + + def __init__(self, rowproxy, charset): + self.rowproxy = rowproxy + self.charset = self._encoding_compat.get(charset, charset) + + def __getitem__(self, index): + item = self.rowproxy[index] + if isinstance(item, _array): + item = item.tostring() + + if self.charset and isinstance(item, util.binary_type): + return item.decode(self.charset) + else: + return item + + def __getattr__(self, attr): + item = getattr(self.rowproxy, attr) + if isinstance(item, _array): + item = item.tostring() + if self.charset and isinstance(item, util.binary_type): + return item.decode(self.charset) + else: + return item diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py new file mode 100644 index 0000000..a67a194 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/cymysql.py @@ -0,0 +1,82 @@ +# mysql/cymysql.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +r""" + +.. dialect:: mysql+cymysql + :name: CyMySQL + :dbapi: cymysql + :connectstring: mysql+cymysql://<username>:<password>@<host>/<dbname>[?<options>] + :url: https://github.com/nakagami/CyMySQL + +.. note:: + + The CyMySQL dialect is **not tested as part of SQLAlchemy's continuous + integration** and may have unresolved issues. The recommended MySQL + dialects are mysqlclient and PyMySQL. + +""" # noqa + +from .base import BIT +from .base import MySQLDialect +from .mysqldb import MySQLDialect_mysqldb +from ... import util + + +class _cymysqlBIT(BIT): + def result_processor(self, dialect, coltype): + """Convert MySQL's 64 bit, variable length binary string to a long.""" + + def process(value): + if value is not None: + v = 0 + for i in util.iterbytes(value): + v = v << 8 | i + return v + return value + + return process + + +class MySQLDialect_cymysql(MySQLDialect_mysqldb): + driver = "cymysql" + supports_statement_cache = True + + description_encoding = None + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + supports_unicode_statements = True + + colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT}) + + @classmethod + def dbapi(cls): + return __import__("cymysql") + + def _detect_charset(self, connection): + return connection.connection.charset + + def _extract_error_code(self, exception): + return exception.errno + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.OperationalError): + return self._extract_error_code(e) in ( + 2006, + 2013, + 2014, + 2045, + 2055, + ) + elif isinstance(e, self.dbapi.InterfaceError): + # if underlying connection is closed, + # this is the error you get + return True + else: + return False + + +dialect = MySQLDialect_cymysql diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py new file mode 100644 index 0000000..0c8791a --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -0,0 +1,175 @@ +from ... import exc +from ... import util +from ...sql.base import _exclusive_against +from ...sql.base import _generative +from ...sql.base import ColumnCollection +from ...sql.dml import Insert as StandardInsert +from ...sql.elements import ClauseElement +from ...sql.expression import alias +from ...util.langhelpers import public_factory + + +__all__ = ("Insert", "insert") + + +class Insert(StandardInsert): + """MySQL-specific implementation of INSERT. + + Adds methods for MySQL-specific syntaxes such as ON DUPLICATE KEY UPDATE. + + The :class:`~.mysql.Insert` object is created using the + :func:`sqlalchemy.dialects.mysql.insert` function. + + .. versionadded:: 1.2 + + """ + + stringify_dialect = "mysql" + inherit_cache = False + + @property + def inserted(self): + """Provide the "inserted" namespace for an ON DUPLICATE KEY UPDATE + statement + + MySQL's ON DUPLICATE KEY UPDATE clause allows reference to the row + that would be inserted, via a special function called ``VALUES()``. + This attribute provides all columns in this row to be referenceable + such that they will render within a ``VALUES()`` function inside the + ON DUPLICATE KEY UPDATE clause. The attribute is named ``.inserted`` + so as not to conflict with the existing + :meth:`_expression.Insert.values` method. + + .. tip:: The :attr:`_mysql.Insert.inserted` attribute is an instance + of :class:`_expression.ColumnCollection`, which provides an + interface the same as that of the :attr:`_schema.Table.c` + collection described at :ref:`metadata_tables_and_columns`. + With this collection, ordinary names are accessible like attributes + (e.g. ``stmt.inserted.some_column``), but special names and + dictionary method names should be accessed using indexed access, + such as ``stmt.inserted["column name"]`` or + ``stmt.inserted["values"]``. See the docstring for + :class:`_expression.ColumnCollection` for further examples. + + .. seealso:: + + :ref:`mysql_insert_on_duplicate_key_update` - example of how + to use :attr:`_expression.Insert.inserted` + + """ + return self.inserted_alias.columns + + @util.memoized_property + def inserted_alias(self): + return alias(self.table, name="inserted") + + @_generative + @_exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already " + "has an ON DUPLICATE KEY clause present" + }, + ) + def on_duplicate_key_update(self, *args, **kw): + r""" + Specifies the ON DUPLICATE KEY UPDATE clause. + + :param \**kw: Column keys linked to UPDATE values. The + values may be any SQL expression or supported literal Python + values. + + .. warning:: This dictionary does **not** take into account + Python-specified default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON DUPLICATE KEY UPDATE + style of UPDATE, unless values are manually specified here. + + :param \*args: As an alternative to passing key/value parameters, + a dictionary or list of 2-tuples can be passed as a single positional + argument. + + Passing a single dictionary is equivalent to the keyword argument + form:: + + insert().on_duplicate_key_update({"name": "some name"}) + + Passing a list of 2-tuples indicates that the parameter assignments + in the UPDATE clause should be ordered as sent, in a manner similar + to that described for the :class:`_expression.Update` + construct overall + in :ref:`tutorial_parameter_ordered_updates`:: + + insert().on_duplicate_key_update( + [("name", "some name"), ("value", "some value")]) + + .. versionchanged:: 1.3 parameters can be specified as a dictionary + or list of 2-tuples; the latter form provides for parameter + ordering. + + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`mysql_insert_on_duplicate_key_update` + + """ + if args and kw: + raise exc.ArgumentError( + "Can't pass kwargs and positional arguments simultaneously" + ) + + if args: + if len(args) > 1: + raise exc.ArgumentError( + "Only a single dictionary or list of tuples " + "is accepted positionally." + ) + values = args[0] + else: + values = kw + + inserted_alias = getattr(self, "inserted_alias", None) + self._post_values_clause = OnDuplicateClause(inserted_alias, values) + + +insert = public_factory( + Insert, ".dialects.mysql.insert", ".dialects.mysql.Insert" +) + + +class OnDuplicateClause(ClauseElement): + __visit_name__ = "on_duplicate_key_update" + + _parameter_ordering = None + + stringify_dialect = "mysql" + + def __init__(self, inserted_alias, update): + self.inserted_alias = inserted_alias + + # auto-detect that parameters should be ordered. This is copied from + # Update._proces_colparams(), however we don't look for a special flag + # in this case since we are not disambiguating from other use cases as + # we are in Update.values(). + if isinstance(update, list) and ( + update and isinstance(update[0], tuple) + ): + self._parameter_ordering = [key for key, value in update] + update = dict(update) + + if isinstance(update, dict): + if not update: + raise ValueError( + "update parameter dictionary must not be empty" + ) + elif isinstance(update, ColumnCollection): + update = dict(update) + else: + raise ValueError( + "update parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) + self.update = update diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py new file mode 100644 index 0000000..6c9ef28 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/enumerated.py @@ -0,0 +1,263 @@ +# mysql/enumerated.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import re + +from .types import _StringType +from ... import exc +from ... import sql +from ... import util +from ...sql import sqltypes +from ...sql.base import NO_ARG + + +class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType): + """MySQL ENUM type.""" + + __visit_name__ = "ENUM" + + native_enum = True + + def __init__(self, *enums, **kw): + """Construct an ENUM. + + E.g.:: + + Column('myenum', ENUM("foo", "bar", "baz")) + + :param enums: The range of valid values for this ENUM. Values in + enums are not quoted, they will be escaped and surrounded by single + quotes when generating the schema. This object may also be a + PEP-435-compliant enumerated type. + + .. versionadded: 1.1 added support for PEP-435-compliant enumerated + types. + + :param strict: This flag has no effect. + + .. versionchanged:: The MySQL ENUM type as well as the base Enum + type now validates all Python data values. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + :param quoting: Not used. A warning will be raised if provided. + + """ + if kw.pop("quoting", NO_ARG) is not NO_ARG: + util.warn_deprecated_20( + "The 'quoting' parameter to :class:`.mysql.ENUM` is deprecated" + " and will be removed in a future release. " + "This parameter now has no effect." + ) + kw.pop("strict", None) + self._enum_init(enums, kw) + _StringType.__init__(self, length=self.length, **kw) + + @classmethod + def adapt_emulated_to_native(cls, impl, **kw): + """Produce a MySQL native :class:`.mysql.ENUM` from plain + :class:`.Enum`. + + """ + kw.setdefault("validate_strings", impl.validate_strings) + kw.setdefault("values_callable", impl.values_callable) + kw.setdefault("omit_aliases", impl._omit_aliases) + return cls(**kw) + + def _object_value_for_elem(self, elem): + # mysql sends back a blank string for any value that + # was persisted that was not in the enums; that is, it does no + # validation on the incoming data, it "truncates" it to be + # the blank string. Return it straight. + if elem == "": + return elem + else: + return super(ENUM, self)._object_value_for_elem(elem) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[ENUM, _StringType, sqltypes.Enum] + ) + + +class SET(_StringType): + """MySQL SET type.""" + + __visit_name__ = "SET" + + def __init__(self, *values, **kw): + """Construct a SET. + + E.g.:: + + Column('myset', SET("foo", "bar", "baz")) + + + The list of potential values is required in the case that this + set will be used to generate DDL for a table, or if the + :paramref:`.SET.retrieve_as_bitwise` flag is set to True. + + :param values: The range of valid values for this SET. The values + are not quoted, they will be escaped and surrounded by single + quotes when generating the schema. + + :param convert_unicode: Same flag as that of + :paramref:`.String.convert_unicode`. + + :param collation: same as that of :paramref:`.String.collation` + + :param charset: same as that of :paramref:`.VARCHAR.charset`. + + :param ascii: same as that of :paramref:`.VARCHAR.ascii`. + + :param unicode: same as that of :paramref:`.VARCHAR.unicode`. + + :param binary: same as that of :paramref:`.VARCHAR.binary`. + + :param retrieve_as_bitwise: if True, the data for the set type will be + persisted and selected using an integer value, where a set is coerced + into a bitwise mask for persistence. MySQL allows this mode which + has the advantage of being able to store values unambiguously, + such as the blank string ``''``. The datatype will appear + as the expression ``col + 0`` in a SELECT statement, so that the + value is coerced into an integer value in result sets. + This flag is required if one wishes + to persist a set that can store the blank string ``''`` as a value. + + .. warning:: + + When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is + essential that the list of set values is expressed in the + **exact same order** as exists on the MySQL database. + + .. versionadded:: 1.0.0 + + :param quoting: Not used. A warning will be raised if passed. + + """ + if kw.pop("quoting", NO_ARG) is not NO_ARG: + util.warn_deprecated_20( + "The 'quoting' parameter to :class:`.mysql.SET` is deprecated" + " and will be removed in a future release. " + "This parameter now has no effect." + ) + self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False) + self.values = tuple(values) + if not self.retrieve_as_bitwise and "" in values: + raise exc.ArgumentError( + "Can't use the blank value '' in a SET without " + "setting retrieve_as_bitwise=True" + ) + if self.retrieve_as_bitwise: + self._bitmap = dict( + (value, 2 ** idx) for idx, value in enumerate(self.values) + ) + self._bitmap.update( + (2 ** idx, value) for idx, value in enumerate(self.values) + ) + length = max([len(v) for v in values] + [0]) + kw.setdefault("length", length) + super(SET, self).__init__(**kw) + + def column_expression(self, colexpr): + if self.retrieve_as_bitwise: + return sql.type_coerce( + sql.type_coerce(colexpr, sqltypes.Integer) + 0, self + ) + else: + return colexpr + + def result_processor(self, dialect, coltype): + if self.retrieve_as_bitwise: + + def process(value): + if value is not None: + value = int(value) + + return set(util.map_bits(self._bitmap.__getitem__, value)) + else: + return None + + else: + super_convert = super(SET, self).result_processor(dialect, coltype) + + def process(value): + if isinstance(value, util.string_types): + # MySQLdb returns a string, let's parse + if super_convert: + value = super_convert(value) + return set(re.findall(r"[^,]+", value)) + else: + # mysql-connector-python does a naive + # split(",") which throws in an empty string + if value is not None: + value.discard("") + return value + + return process + + def bind_processor(self, dialect): + super_convert = super(SET, self).bind_processor(dialect) + if self.retrieve_as_bitwise: + + def process(value): + if value is None: + return None + elif isinstance(value, util.int_types + util.string_types): + if super_convert: + return super_convert(value) + else: + return value + else: + int_value = 0 + for v in value: + int_value |= self._bitmap[v] + return int_value + + else: + + def process(value): + # accept strings and int (actually bitflag) values directly + if value is not None and not isinstance( + value, util.int_types + util.string_types + ): + value = ",".join(value) + + if super_convert: + return super_convert(value) + else: + return value + + return process + + def adapt(self, impltype, **kw): + kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise + return util.constructor_copy(self, impltype, *self.values, **kw) + + def __repr__(self): + return util.generic_repr( + self, + to_inspect=[SET, _StringType], + additional_kw=[ + ("retrieve_as_bitwise", False), + ], + ) diff --git a/lib/sqlalchemy/dialects/mysql/expression.py b/lib/sqlalchemy/dialects/mysql/expression.py new file mode 100644 index 0000000..7a66e9b --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/expression.py @@ -0,0 +1,130 @@ +from ... import exc +from ... import util +from ...sql import coercions +from ...sql import elements +from ...sql import operators +from ...sql import roles +from ...sql.base import _generative +from ...sql.base import Generative + + +class match(Generative, elements.BinaryExpression): + """Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause. + + E.g.:: + + from sqlalchemy import desc + from sqlalchemy.dialects.mysql import match + + match_expr = match( + users_table.c.firstname, + users_table.c.lastname, + against="Firstname Lastname", + ) + + stmt = ( + select(users_table) + .where(match_expr.in_boolean_mode()) + .order_by(desc(match_expr)) + ) + + Would produce SQL resembling:: + + SELECT id, firstname, lastname + FROM user + WHERE MATCH(firstname, lastname) AGAINST (:param_1 IN BOOLEAN MODE) + ORDER BY MATCH(firstname, lastname) AGAINST (:param_2) DESC + + The :func:`_mysql.match` function is a standalone version of the + :meth:`_sql.ColumnElement.match` method available on all + SQL expressions, as when :meth:`_expression.ColumnElement.match` is + used, but allows to pass multiple columns + + :param cols: column expressions to match against + + :param against: expression to be compared towards + + :param in_boolean_mode: boolean, set "boolean mode" to true + + :param in_natural_language_mode: boolean , set "natural language" to true + + :param with_query_expansion: boolean, set "query expansion" to true + + .. versionadded:: 1.4.19 + + .. seealso:: + + :meth:`_expression.ColumnElement.match` + + """ + + __visit_name__ = "mysql_match" + + inherit_cache = True + + def __init__(self, *cols, **kw): + if not cols: + raise exc.ArgumentError("columns are required") + + against = kw.pop("against", None) + + if against is None: + raise exc.ArgumentError("against is required") + against = coercions.expect( + roles.ExpressionElementRole, + against, + ) + + left = elements.BooleanClauseList._construct_raw( + operators.comma_op, + clauses=cols, + ) + left.group = False + + flags = util.immutabledict( + { + "mysql_boolean_mode": kw.pop("in_boolean_mode", False), + "mysql_natural_language": kw.pop( + "in_natural_language_mode", False + ), + "mysql_query_expansion": kw.pop("with_query_expansion", False), + } + ) + + if kw: + raise exc.ArgumentError("unknown arguments: %s" % (", ".join(kw))) + + super(match, self).__init__( + left, against, operators.match_op, modifiers=flags + ) + + @_generative + def in_boolean_mode(self): + """Apply the "IN BOOLEAN MODE" modifier to the MATCH expression. + + :return: a new :class:`_mysql.match` instance with modifications + applied. + """ + + self.modifiers = self.modifiers.union({"mysql_boolean_mode": True}) + + @_generative + def in_natural_language_mode(self): + """Apply the "IN NATURAL LANGUAGE MODE" modifier to the MATCH + expression. + + :return: a new :class:`_mysql.match` instance with modifications + applied. + """ + + self.modifiers = self.modifiers.union({"mysql_natural_language": True}) + + @_generative + def with_query_expansion(self): + """Apply the "WITH QUERY EXPANSION" modifier to the MATCH expression. + + :return: a new :class:`_mysql.match` instance with modifications + applied. + """ + + self.modifiers = self.modifiers.union({"mysql_query_expansion": True}) diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py new file mode 100644 index 0000000..857fcce --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/json.py @@ -0,0 +1,84 @@ +# mysql/json.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import absolute_import + +from ... import types as sqltypes + + +class JSON(sqltypes.JSON): + """MySQL JSON type. + + MySQL supports JSON as of version 5.7. + MariaDB supports JSON (as an alias for LONGTEXT) as of version 10.2. + + :class:`_mysql.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a MySQL or MariaDB backend. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. + + The :class:`.mysql.JSON` type supports persistence of JSON values + as well as the core index operations provided by :class:`_types.JSON` + datatype, by adapting the operations to render the ``JSON_EXTRACT`` + function at the database level. + + .. versionadded:: 1.1 + + """ + + pass + + +class _FormatTypeMixin(object): + def _format_value(self, value): + raise NotImplementedError() + + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + def _format_value(self, value): + if isinstance(value, int): + value = "$[%s]" % value + else: + value = '$."%s"' % value + return value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return "$%s" % ( + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) + ) diff --git a/lib/sqlalchemy/dialects/mysql/mariadb.py b/lib/sqlalchemy/dialects/mysql/mariadb.py new file mode 100644 index 0000000..568c3f0 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/mariadb.py @@ -0,0 +1,25 @@ +from .base import MariaDBIdentifierPreparer +from .base import MySQLDialect + + +class MariaDBDialect(MySQLDialect): + is_mariadb = True + supports_statement_cache = True + name = "mariadb" + preparer = MariaDBIdentifierPreparer + + +def loader(driver): + driver_mod = __import__( + "sqlalchemy.dialects.mysql.%s" % driver + ).dialects.mysql + driver_cls = getattr(driver_mod, driver).dialect + + return type( + "MariaDBDialect_%s" % driver, + ( + MariaDBDialect, + driver_cls, + ), + {"supports_statement_cache": True}, + ) diff --git a/lib/sqlalchemy/dialects/mysql/mariadbconnector.py b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py new file mode 100644 index 0000000..c8b2ead --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -0,0 +1,240 @@ +# mysql/mariadbconnector.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: mysql+mariadbconnector + :name: MariaDB Connector/Python + :dbapi: mariadb + :connectstring: mariadb+mariadbconnector://<user>:<password>@<host>[:<port>]/<dbname> + :url: https://pypi.org/project/mariadb/ + +Driver Status +------------- + +MariaDB Connector/Python enables Python programs to access MariaDB and MySQL +databases using an API which is compliant with the Python DB API 2.0 (PEP-249). +It is written in C and uses MariaDB Connector/C client library for client server +communication. + +Note that the default driver for a ``mariadb://`` connection URI continues to +be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver. + +.. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python + +""" # noqa +import re + +from .base import MySQLCompiler +from .base import MySQLDialect +from .base import MySQLExecutionContext +from ... import sql +from ... import util + +mariadb_cpy_minimum_version = (1, 0, 1) + + +class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext): + _lastrowid = None + + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(buffered=False) + + def create_default_cursor(self): + return self._dbapi_connection.cursor(buffered=True) + + def post_exec(self): + if self.isinsert and self.compiled.postfetch_lastrowid: + self._lastrowid = self.cursor.lastrowid + + def get_lastrowid(self): + return self._lastrowid + + +class MySQLCompiler_mariadbconnector(MySQLCompiler): + pass + + +class MySQLDialect_mariadbconnector(MySQLDialect): + driver = "mariadbconnector" + supports_statement_cache = True + + # set this to True at the module level to prevent the driver from running + # against a backend that server detects as MySQL. currently this appears to + # be unnecessary as MariaDB client libraries have always worked against + # MySQL databases. However, if this changes at some point, this can be + # adjusted, but PLEASE ADD A TEST in test/dialect/mysql/test_dialect.py if + # this change is made at some point to ensure the correct exception + # is raised at the correct point when running the driver against + # a MySQL backend. + # is_mariadb = True + + supports_unicode_statements = True + encoding = "utf8mb4" + convert_unicode = True + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + supports_native_decimal = True + default_paramstyle = "qmark" + execution_ctx_cls = MySQLExecutionContext_mariadbconnector + statement_compiler = MySQLCompiler_mariadbconnector + + supports_server_side_cursors = True + + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + return tuple( + [ + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) + else: + return (99, 99, 99) + + def __init__(self, **kwargs): + super(MySQLDialect_mariadbconnector, self).__init__(**kwargs) + self.paramstyle = "qmark" + if self.dbapi is not None: + if self._dbapi_version < mariadb_cpy_minimum_version: + raise NotImplementedError( + "The minimum required version for MariaDB " + "Connector/Python is %s" + % ".".join(str(x) for x in mariadb_cpy_minimum_version) + ) + + @classmethod + def dbapi(cls): + return __import__("mariadb") + + def is_disconnect(self, e, connection, cursor): + if super(MySQLDialect_mariadbconnector, self).is_disconnect( + e, connection, cursor + ): + return True + elif isinstance(e, self.dbapi.Error): + str_e = str(e).lower() + return "not connected" in str_e or "isn't valid" in str_e + else: + return False + + def create_connect_args(self, url): + opts = url.translate_connect_args() + + int_params = [ + "connect_timeout", + "read_timeout", + "write_timeout", + "client_flag", + "port", + "pool_size", + ] + bool_params = [ + "local_infile", + "ssl_verify_cert", + "ssl", + "pool_reset_connection", + ] + + for key in int_params: + util.coerce_kw_type(opts, key, int) + for key in bool_params: + util.coerce_kw_type(opts, key, bool) + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + client_flag = opts.get("client_flag", 0) + if self.dbapi is not None: + try: + CLIENT_FLAGS = __import__( + self.dbapi.__name__ + ".constants.CLIENT" + ).constants.CLIENT + client_flag |= CLIENT_FLAGS.FOUND_ROWS + except (AttributeError, ImportError): + self.supports_sane_rowcount = False + opts["client_flag"] = client_flag + return [[], opts] + + def _extract_error_code(self, exception): + try: + rc = exception.errno + except: + rc = -1 + return rc + + def _detect_charset(self, connection): + return "utf8mb4" + + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ] + ) + + def _set_isolation_level(self, connection, level): + if level == "AUTOCOMMIT": + connection.autocommit = True + else: + connection.autocommit = False + super(MySQLDialect_mariadbconnector, self)._set_isolation_level( + connection, level + ) + + def do_begin_twophase(self, connection, xid): + connection.execute( + sql.text("XA BEGIN :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + def do_prepare_twophase(self, connection, xid): + connection.execute( + sql.text("XA END :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + connection.execute( + sql.text("XA PREPARE :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + connection.execute( + sql.text("XA END :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + connection.execute( + sql.text("XA ROLLBACK :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute( + sql.text("XA COMMIT :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + +dialect = MySQLDialect_mariadbconnector diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py new file mode 100644 index 0000000..356babe --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -0,0 +1,240 @@ +# mysql/mysqlconnector.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" +.. dialect:: mysql+mysqlconnector + :name: MySQL Connector/Python + :dbapi: myconnpy + :connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname> + :url: https://pypi.org/project/mysql-connector-python/ + +.. note:: + + The MySQL Connector/Python DBAPI has had many issues since its release, + some of which may remain unresolved, and the mysqlconnector dialect is + **not tested as part of SQLAlchemy's continuous integration**. + The recommended MySQL dialects are mysqlclient and PyMySQL. + +""" # noqa + +import re + +from .base import BIT +from .base import MySQLCompiler +from .base import MySQLDialect +from .base import MySQLIdentifierPreparer +from ... import processors +from ... import util + + +class MySQLCompiler_mysqlconnector(MySQLCompiler): + def visit_mod_binary(self, binary, operator, **kw): + if self.dialect._mysqlconnector_double_percents: + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) + else: + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) + + def post_process_text(self, text): + if self.dialect._mysqlconnector_double_percents: + return text.replace("%", "%%") + else: + return text + + def escape_literal_column(self, text): + if self.dialect._mysqlconnector_double_percents: + return text.replace("%", "%%") + else: + return text + + +class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer): + @property + def _double_percents(self): + return self.dialect._mysqlconnector_double_percents + + @_double_percents.setter + def _double_percents(self, value): + pass + + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + if self.dialect._mysqlconnector_double_percents: + return value.replace("%", "%%") + else: + return value + + +class _myconnpyBIT(BIT): + def result_processor(self, dialect, coltype): + """MySQL-connector already converts mysql bits, so.""" + + return None + + +class MySQLDialect_mysqlconnector(MySQLDialect): + driver = "mysqlconnector" + supports_statement_cache = True + + supports_unicode_binds = True + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + supports_native_decimal = True + + default_paramstyle = "format" + statement_compiler = MySQLCompiler_mysqlconnector + + preparer = MySQLIdentifierPreparer_mysqlconnector + + colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT}) + + def __init__(self, *arg, **kw): + super(MySQLDialect_mysqlconnector, self).__init__(*arg, **kw) + + # hack description encoding since mysqlconnector randomly + # returns bytes or not + self._description_decoder = ( + processors.to_conditional_unicode_processor_factory + )(self.description_encoding) + + def _check_unicode_description(self, connection): + # hack description encoding since mysqlconnector randomly + # returns bytes or not + return False + + @property + def description_encoding(self): + # total guess + return "latin-1" + + @util.memoized_property + def supports_unicode_statements(self): + return util.py3k or self._mysqlconnector_version_info > (2, 0) + + @classmethod + def dbapi(cls): + from mysql import connector + + return connector + + def do_ping(self, dbapi_connection): + try: + dbapi_connection.ping(False) + except self.dbapi.Error as err: + if self.is_disconnect(err, dbapi_connection, None): + return False + else: + raise + else: + return True + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + + opts.update(url.query) + + util.coerce_kw_type(opts, "allow_local_infile", bool) + util.coerce_kw_type(opts, "autocommit", bool) + util.coerce_kw_type(opts, "buffered", bool) + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "connection_timeout", int) + util.coerce_kw_type(opts, "connect_timeout", int) + util.coerce_kw_type(opts, "consume_results", bool) + util.coerce_kw_type(opts, "force_ipv6", bool) + util.coerce_kw_type(opts, "get_warnings", bool) + util.coerce_kw_type(opts, "pool_reset_session", bool) + util.coerce_kw_type(opts, "pool_size", int) + util.coerce_kw_type(opts, "raise_on_warnings", bool) + util.coerce_kw_type(opts, "raw", bool) + util.coerce_kw_type(opts, "ssl_verify_cert", bool) + util.coerce_kw_type(opts, "use_pure", bool) + util.coerce_kw_type(opts, "use_unicode", bool) + + # unfortunately, MySQL/connector python refuses to release a + # cursor without reading fully, so non-buffered isn't an option + opts.setdefault("buffered", True) + + # FOUND_ROWS must be set in ClientFlag to enable + # supports_sane_rowcount. + if self.dbapi is not None: + try: + from mysql.connector.constants import ClientFlag + + client_flags = opts.get( + "client_flags", ClientFlag.get_default() + ) + client_flags |= ClientFlag.FOUND_ROWS + opts["client_flags"] = client_flags + except Exception: + pass + return [[], opts] + + @util.memoized_property + def _mysqlconnector_version_info(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) + if m: + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) + + @util.memoized_property + def _mysqlconnector_double_percents(self): + return not util.py3k and self._mysqlconnector_version_info < (2, 0) + + def _detect_charset(self, connection): + return connection.connection.charset + + def _extract_error_code(self, exception): + return exception.errno + + def is_disconnect(self, e, connection, cursor): + errnos = (2006, 2013, 2014, 2045, 2055, 2048) + exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError) + if isinstance(e, exceptions): + return ( + e.errno in errnos + or "MySQL Connection not available." in str(e) + or "Connection to MySQL is not available" in str(e) + ) + else: + return False + + def _compat_fetchall(self, rp, charset=None): + return rp.fetchall() + + def _compat_fetchone(self, rp, charset=None): + return rp.fetchone() + + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ] + ) + + def _set_isolation_level(self, connection, level): + if level == "AUTOCOMMIT": + connection.autocommit = True + else: + connection.autocommit = False + super(MySQLDialect_mysqlconnector, self)._set_isolation_level( + connection, level + ) + + +dialect = MySQLDialect_mysqlconnector diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py new file mode 100644 index 0000000..7a721e8 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -0,0 +1,331 @@ +# mysql/mysqldb.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: mysql+mysqldb + :name: mysqlclient (maintained fork of MySQL-Python) + :dbapi: mysqldb + :connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname> + :url: https://pypi.org/project/mysqlclient/ + +Driver Status +------------- + +The mysqlclient DBAPI is a maintained fork of the +`MySQL-Python <https://sourceforge.net/projects/mysql-python>`_ DBAPI +that is no longer maintained. `mysqlclient`_ supports Python 2 and Python 3 +and is very stable. + +.. _mysqlclient: https://github.com/PyMySQL/mysqlclient-python + +.. _mysqldb_unicode: + +Unicode +------- + +Please see :ref:`mysql_unicode` for current recommendations on unicode +handling. + +.. _mysqldb_ssl: + +SSL Connections +---------------- + +The mysqlclient and PyMySQL DBAPIs accept an additional dictionary under the +key "ssl", which may be specified using the +:paramref:`_sa.create_engine.connect_args` dictionary:: + + engine = create_engine( + "mysql+mysqldb://scott:tiger@192.168.0.134/test", + connect_args={ + "ssl": { + "ssl_ca": "/home/gord/client-ssl/ca.pem", + "ssl_cert": "/home/gord/client-ssl/client-cert.pem", + "ssl_key": "/home/gord/client-ssl/client-key.pem" + } + } + ) + +For convenience, the following keys may also be specified inline within the URL +where they will be interpreted into the "ssl" dictionary automatically: +"ssl_ca", "ssl_cert", "ssl_key", "ssl_capath", "ssl_cipher", +"ssl_check_hostname". An example is as follows:: + + connection_uri = ( + "mysql+mysqldb://scott:tiger@192.168.0.134/test" + "?ssl_ca=/home/gord/client-ssl/ca.pem" + "&ssl_cert=/home/gord/client-ssl/client-cert.pem" + "&ssl_key=/home/gord/client-ssl/client-key.pem" + ) + +If the server uses an automatically-generated certificate that is self-signed +or does not match the host name (as seen from the client), it may also be +necessary to indicate ``ssl_check_hostname=false``:: + + connection_uri = ( + "mysql+pymysql://scott:tiger@192.168.0.134/test" + "?ssl_ca=/home/gord/client-ssl/ca.pem" + "&ssl_cert=/home/gord/client-ssl/client-cert.pem" + "&ssl_key=/home/gord/client-ssl/client-key.pem" + "&ssl_check_hostname=false" + ) + + +.. seealso:: + + :ref:`pymysql_ssl` in the PyMySQL dialect + + +Using MySQLdb with Google Cloud SQL +----------------------------------- + +Google Cloud SQL now recommends use of the MySQLdb dialect. Connect +using a URL like the following:: + + mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename> + +Server Side Cursors +------------------- + +The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`. + +""" + +import re + +from .base import MySQLCompiler +from .base import MySQLDialect +from .base import MySQLExecutionContext +from .base import MySQLIdentifierPreparer +from .base import TEXT +from ... import sql +from ... import util + + +class MySQLExecutionContext_mysqldb(MySQLExecutionContext): + @property + def rowcount(self): + if hasattr(self, "_rowcount"): + return self._rowcount + else: + return self.cursor.rowcount + + +class MySQLCompiler_mysqldb(MySQLCompiler): + pass + + +class MySQLDialect_mysqldb(MySQLDialect): + driver = "mysqldb" + supports_statement_cache = True + supports_unicode_statements = True + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + supports_native_decimal = True + + default_paramstyle = "format" + execution_ctx_cls = MySQLExecutionContext_mysqldb + statement_compiler = MySQLCompiler_mysqldb + preparer = MySQLIdentifierPreparer + + def __init__(self, **kwargs): + super(MySQLDialect_mysqldb, self).__init__(**kwargs) + self._mysql_dbapi_version = ( + self._parse_dbapi_version(self.dbapi.__version__) + if self.dbapi is not None and hasattr(self.dbapi, "__version__") + else (0, 0, 0) + ) + + def _parse_dbapi_version(self, version): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) + if m: + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) + else: + return (0, 0, 0) + + @util.langhelpers.memoized_property + def supports_server_side_cursors(self): + try: + cursors = __import__("MySQLdb.cursors").cursors + self._sscursor = cursors.SSCursor + return True + except (ImportError, AttributeError): + return False + + @classmethod + def dbapi(cls): + return __import__("MySQLdb") + + def on_connect(self): + super_ = super(MySQLDialect_mysqldb, self).on_connect() + + def on_connect(conn): + if super_ is not None: + super_(conn) + + charset_name = conn.character_set_name() + + if charset_name is not None: + cursor = conn.cursor() + cursor.execute("SET NAMES %s" % charset_name) + cursor.close() + + return on_connect + + def do_ping(self, dbapi_connection): + try: + dbapi_connection.ping(False) + except self.dbapi.Error as err: + if self.is_disconnect(err, dbapi_connection, None): + return False + else: + raise + else: + return True + + def do_executemany(self, cursor, statement, parameters, context=None): + rowcount = cursor.executemany(statement, parameters) + if context is not None: + context._rowcount = rowcount + + def _check_unicode_returns(self, connection): + # work around issue fixed in + # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8 + # specific issue w/ the utf8mb4_bin collation and unicode returns + + collation = connection.exec_driver_sql( + "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'" + % ( + self.identifier_preparer.quote("Charset"), + self.identifier_preparer.quote("Collation"), + ) + ).scalar() + has_utf8mb4_bin = self.server_version_info > (5,) and collation + if has_utf8mb4_bin: + additional_tests = [ + sql.collate( + sql.cast( + sql.literal_column("'test collated returns'"), + TEXT(charset="utf8mb4"), + ), + "utf8mb4_bin", + ) + ] + else: + additional_tests = [] + return super(MySQLDialect_mysqldb, self)._check_unicode_returns( + connection, additional_tests + ) + + def create_connect_args(self, url, _translate_args=None): + if _translate_args is None: + _translate_args = dict( + database="db", username="user", password="passwd" + ) + + opts = url.translate_connect_args(**_translate_args) + opts.update(url.query) + + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "connect_timeout", int) + util.coerce_kw_type(opts, "read_timeout", int) + util.coerce_kw_type(opts, "write_timeout", int) + util.coerce_kw_type(opts, "client_flag", int) + util.coerce_kw_type(opts, "local_infile", int) + # Note: using either of the below will cause all strings to be + # returned as Unicode, both in raw SQL operations and with column + # types like String and MSString. + util.coerce_kw_type(opts, "use_unicode", bool) + util.coerce_kw_type(opts, "charset", str) + + # Rich values 'cursorclass' and 'conv' are not supported via + # query string. + + ssl = {} + keys = [ + ("ssl_ca", str), + ("ssl_key", str), + ("ssl_cert", str), + ("ssl_capath", str), + ("ssl_cipher", str), + ("ssl_check_hostname", bool), + ] + for key, kw_type in keys: + if key in opts: + ssl[key[4:]] = opts[key] + util.coerce_kw_type(ssl, key[4:], kw_type) + del opts[key] + if ssl: + opts["ssl"] = ssl + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + client_flag = opts.get("client_flag", 0) + + client_flag_found_rows = self._found_rows_client_flag() + if client_flag_found_rows is not None: + client_flag |= client_flag_found_rows + opts["client_flag"] = client_flag + return [[], opts] + + def _found_rows_client_flag(self): + if self.dbapi is not None: + try: + CLIENT_FLAGS = __import__( + self.dbapi.__name__ + ".constants.CLIENT" + ).constants.CLIENT + except (AttributeError, ImportError): + return None + else: + return CLIENT_FLAGS.FOUND_ROWS + else: + return None + + def _extract_error_code(self, exception): + return exception.args[0] + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + try: + # note: the SQL here would be + # "SHOW VARIABLES LIKE 'character_set%%'" + cset_name = connection.connection.character_set_name + except AttributeError: + util.warn( + "No 'character_set_name' can be detected with " + "this MySQL-Python version; " + "please upgrade to a recent version of MySQL-Python. " + "Assuming latin1." + ) + return "latin1" + else: + return cset_name() + + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ] + ) + + def _set_isolation_level(self, connection, level): + if level == "AUTOCOMMIT": + connection.autocommit(True) + else: + connection.autocommit(False) + super(MySQLDialect_mysqldb, self)._set_isolation_level( + connection, level + ) + + +dialect = MySQLDialect_mysqldb diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py new file mode 100644 index 0000000..f6287dc --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -0,0 +1,273 @@ +# mysql/oursql.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: mysql+oursql + :name: OurSQL + :dbapi: oursql + :connectstring: mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname> + :url: https://packages.python.org/oursql/ + +.. note:: + + The OurSQL MySQL dialect is legacy and is no longer supported upstream, + and is **not tested as part of SQLAlchemy's continuous integration**. + The recommended MySQL dialects are mysqlclient and PyMySQL. + +.. deprecated:: 1.4 The OurSQL DBAPI is deprecated and will be removed + in a future version. Please use one of the supported DBAPIs to + connect to mysql. + +Unicode +------- + +Please see :ref:`mysql_unicode` for current recommendations on unicode +handling. + + +""" + + +from .base import BIT +from .base import MySQLDialect +from .base import MySQLExecutionContext +from ... import types as sqltypes +from ... import util + + +class _oursqlBIT(BIT): + def result_processor(self, dialect, coltype): + """oursql already converts mysql bits, so.""" + + return None + + +class MySQLExecutionContext_oursql(MySQLExecutionContext): + @property + def plain_query(self): + return self.execution_options.get("_oursql_plain_query", False) + + +class MySQLDialect_oursql(MySQLDialect): + driver = "oursql" + supports_statement_cache = True + + if util.py2k: + supports_unicode_binds = True + supports_unicode_statements = True + + supports_native_decimal = True + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + execution_ctx_cls = MySQLExecutionContext_oursql + + colspecs = util.update_copy( + MySQLDialect.colspecs, {sqltypes.Time: sqltypes.Time, BIT: _oursqlBIT} + ) + + @classmethod + def dbapi(cls): + util.warn_deprecated( + "The OurSQL DBAPI is deprecated and will be removed " + "in a future version. Please use one of the supported DBAPIs to " + "connect to mysql.", + version="1.4", + ) + return __import__("oursql") + + def do_execute(self, cursor, statement, parameters, context=None): + """Provide an implementation of + *cursor.execute(statement, parameters)*.""" + + if context and context.plain_query: + cursor.execute(statement, plain_query=True) + else: + cursor.execute(statement, parameters) + + def do_begin(self, connection): + connection.cursor().execute("BEGIN", plain_query=True) + + def _xa_query(self, connection, query, xid): + if util.py2k: + arg = connection.connection._escape_string(xid) + else: + charset = self._connection_charset + arg = connection.connection._escape_string( + xid.encode(charset) + ).decode(charset) + arg = "'%s'" % arg + connection.execution_options(_oursql_plain_query=True).exec_driver_sql( + query % arg + ) + + # Because mysql is bad, these methods have to be + # reimplemented to use _PlainQuery. Basically, some queries + # refuse to return any data if they're run through + # the parameterized query API, or refuse to be parameterized + # in the first place. + def do_begin_twophase(self, connection, xid): + self._xa_query(connection, "XA BEGIN %s", xid) + + def do_prepare_twophase(self, connection, xid): + self._xa_query(connection, "XA END %s", xid) + self._xa_query(connection, "XA PREPARE %s", xid) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + self._xa_query(connection, "XA END %s", xid) + self._xa_query(connection, "XA ROLLBACK %s", xid) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + self._xa_query(connection, "XA COMMIT %s", xid) + + # Q: why didn't we need all these "plain_query" overrides earlier ? + # am i on a newer/older version of OurSQL ? + def has_table(self, connection, table_name, schema=None): + return MySQLDialect.has_table( + self, + connection.connect().execution_options(_oursql_plain_query=True), + table_name, + schema, + ) + + def get_table_options(self, connection, table_name, schema=None, **kw): + return MySQLDialect.get_table_options( + self, + connection.connect().execution_options(_oursql_plain_query=True), + table_name, + schema=schema, + **kw + ) + + def get_columns(self, connection, table_name, schema=None, **kw): + return MySQLDialect.get_columns( + self, + connection.connect().execution_options(_oursql_plain_query=True), + table_name, + schema=schema, + **kw + ) + + def get_view_names(self, connection, schema=None, **kw): + return MySQLDialect.get_view_names( + self, + connection.connect().execution_options(_oursql_plain_query=True), + schema=schema, + **kw + ) + + def get_table_names(self, connection, schema=None, **kw): + return MySQLDialect.get_table_names( + self, + connection.connect().execution_options(_oursql_plain_query=True), + schema, + ) + + def get_schema_names(self, connection, **kw): + return MySQLDialect.get_schema_names( + self, + connection.connect().execution_options(_oursql_plain_query=True), + **kw + ) + + def initialize(self, connection): + return MySQLDialect.initialize( + self, connection.execution_options(_oursql_plain_query=True) + ) + + def _show_create_table( + self, connection, table, charset=None, full_name=None + ): + return MySQLDialect._show_create_table( + self, + connection.connect(close_with_result=True).execution_options( + _oursql_plain_query=True + ), + table, + charset, + full_name, + ) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.ProgrammingError): + return ( + e.errno is None + and "cursor" not in e.args[1] + and e.args[1].endswith("closed") + ) + else: + return e.errno in (2006, 2013, 2014, 2045, 2055) + + def create_connect_args(self, url): + opts = url.translate_connect_args( + database="db", username="user", password="passwd" + ) + opts.update(url.query) + + util.coerce_kw_type(opts, "port", int) + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "autoping", bool) + util.coerce_kw_type(opts, "raise_on_warnings", bool) + + util.coerce_kw_type(opts, "default_charset", bool) + if opts.pop("default_charset", False): + opts["charset"] = None + else: + util.coerce_kw_type(opts, "charset", str) + opts["use_unicode"] = opts.get("use_unicode", True) + util.coerce_kw_type(opts, "use_unicode", bool) + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + opts.setdefault("found_rows", True) + + ssl = {} + for key in [ + "ssl_ca", + "ssl_key", + "ssl_cert", + "ssl_capath", + "ssl_cipher", + ]: + if key in opts: + ssl[key[4:]] = opts[key] + util.coerce_kw_type(ssl, key[4:], str) + del opts[key] + if ssl: + opts["ssl"] = ssl + + return [[], opts] + + def _extract_error_code(self, exception): + return exception.errno + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + return connection.connection.charset + + def _compat_fetchall(self, rp, charset=None): + """oursql isn't super-broken like MySQLdb, yaaay.""" + return rp.fetchall() + + def _compat_fetchone(self, rp, charset=None): + """oursql isn't super-broken like MySQLdb, yaaay.""" + return rp.fetchone() + + def _compat_first(self, rp, charset=None): + return rp.first() + + +dialect = MySQLDialect_oursql diff --git a/lib/sqlalchemy/dialects/mysql/provision.py b/lib/sqlalchemy/dialects/mysql/provision.py new file mode 100644 index 0000000..86aaa94 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/provision.py @@ -0,0 +1,78 @@ +from ... import exc +from ...testing.provision import configure_follower +from ...testing.provision import create_db +from ...testing.provision import drop_db +from ...testing.provision import generate_driver_url +from ...testing.provision import temp_table_keyword_args + + +@generate_driver_url.for_db("mysql", "mariadb") +def generate_driver_url(url, driver, query_str): + backend = url.get_backend_name() + + # NOTE: at the moment, tests are running mariadbconnector + # against both mariadb and mysql backends. if we want this to be + # limited, do the decision making here to reject a "mysql+mariadbconnector" + # URL. Optionally also re-enable the module level + # MySQLDialect_mariadbconnector.is_mysql flag as well, which must include + # a unit and/or functional test. + + # all the Jenkins tests have been running mysqlclient Python library + # built against mariadb client drivers for years against all MySQL / + # MariaDB versions going back to MySQL 5.6, currently they can talk + # to MySQL databases without problems. + + if backend == "mysql": + dialect_cls = url.get_dialect() + if dialect_cls._is_mariadb_from_url(url): + backend = "mariadb" + + new_url = url.set( + drivername="%s+%s" % (backend, driver) + ).update_query_string(query_str) + + try: + new_url.get_dialect() + except exc.NoSuchModuleError: + return None + else: + return new_url + + +@create_db.for_db("mysql", "mariadb") +def _mysql_create_db(cfg, eng, ident): + with eng.begin() as conn: + try: + _mysql_drop_db(cfg, conn, ident) + except Exception: + pass + + with eng.begin() as conn: + conn.exec_driver_sql( + "CREATE DATABASE %s CHARACTER SET utf8mb4" % ident + ) + conn.exec_driver_sql( + "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb4" % ident + ) + conn.exec_driver_sql( + "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb4" % ident + ) + + +@configure_follower.for_db("mysql", "mariadb") +def _mysql_configure_follower(config, ident): + config.test_schema = "%s_test_schema" % ident + config.test_schema_2 = "%s_test_schema_2" % ident + + +@drop_db.for_db("mysql", "mariadb") +def _mysql_drop_db(cfg, eng, ident): + with eng.begin() as conn: + conn.exec_driver_sql("DROP DATABASE %s_test_schema" % ident) + conn.exec_driver_sql("DROP DATABASE %s_test_schema_2" % ident) + conn.exec_driver_sql("DROP DATABASE %s" % ident) + + +@temp_table_keyword_args.for_db("mysql", "mariadb") +def _mysql_temp_table_keyword_args(cfg, eng): + return {"prefixes": ["TEMPORARY"]} diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py new file mode 100644 index 0000000..f620133 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -0,0 +1,98 @@ +# mysql/pymysql.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" + +.. dialect:: mysql+pymysql + :name: PyMySQL + :dbapi: pymysql + :connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>[?<options>] + :url: https://pymysql.readthedocs.io/ + +Unicode +------- + +Please see :ref:`mysql_unicode` for current recommendations on unicode +handling. + +.. _pymysql_ssl: + +SSL Connections +------------------ + +The PyMySQL DBAPI accepts the same SSL arguments as that of MySQLdb, +described at :ref:`mysqldb_ssl`. See that section for examples. + + +MySQL-Python Compatibility +-------------------------- + +The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver, +and targets 100% compatibility. Most behavioral notes for MySQL-python apply +to the pymysql driver as well. + +""" # noqa + +from .mysqldb import MySQLDialect_mysqldb +from ...util import langhelpers +from ...util import py3k + + +class MySQLDialect_pymysql(MySQLDialect_mysqldb): + driver = "pymysql" + supports_statement_cache = True + + description_encoding = None + + # generally, these two values should be both True + # or both False. PyMySQL unicode tests pass all the way back + # to 0.4 either way. See [ticket:3337] + supports_unicode_statements = True + supports_unicode_binds = True + + @langhelpers.memoized_property + def supports_server_side_cursors(self): + try: + cursors = __import__("pymysql.cursors").cursors + self._sscursor = cursors.SSCursor + return True + except (ImportError, AttributeError): + return False + + @classmethod + def dbapi(cls): + return __import__("pymysql") + + def create_connect_args(self, url, _translate_args=None): + if _translate_args is None: + _translate_args = dict(username="user") + return super(MySQLDialect_pymysql, self).create_connect_args( + url, _translate_args=_translate_args + ) + + def is_disconnect(self, e, connection, cursor): + if super(MySQLDialect_pymysql, self).is_disconnect( + e, connection, cursor + ): + return True + elif isinstance(e, self.dbapi.Error): + str_e = str(e).lower() + return ( + "already closed" in str_e or "connection was killed" in str_e + ) + else: + return False + + if py3k: + + def _extract_error_code(self, exception): + if isinstance(exception.args[0], Exception): + exception = exception.args[0] + return exception.args[0] + + +dialect = MySQLDialect_pymysql diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py new file mode 100644 index 0000000..bfa61f6 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -0,0 +1,136 @@ +# mysql/pyodbc.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" + + +.. dialect:: mysql+pyodbc + :name: PyODBC + :dbapi: pyodbc + :connectstring: mysql+pyodbc://<username>:<password>@<dsnname> + :url: https://pypi.org/project/pyodbc/ + +.. note:: + + The PyODBC for MySQL dialect is **not tested as part of + SQLAlchemy's continuous integration**. + The recommended MySQL dialects are mysqlclient and PyMySQL. + However, if you want to use the mysql+pyodbc dialect and require + full support for ``utf8mb4`` characters (including supplementary + characters like emoji) be sure to use a current release of + MySQL Connector/ODBC and specify the "ANSI" (**not** "Unicode") + version of the driver in your DSN or connection string. + +Pass through exact pyodbc connection string:: + + import urllib + connection_string = ( + 'DRIVER=MySQL ODBC 8.0 ANSI Driver;' + 'SERVER=localhost;' + 'PORT=3307;' + 'DATABASE=mydb;' + 'UID=root;' + 'PWD=(whatever);' + 'charset=utf8mb4;' + ) + params = urllib.parse.quote_plus(connection_string) + connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params + +""" # noqa + +import re + +from .base import MySQLDialect +from .base import MySQLExecutionContext +from .types import TIME +from ... import exc +from ... import util +from ...connectors.pyodbc import PyODBCConnector +from ...sql.sqltypes import Time + + +class _pyodbcTIME(TIME): + def result_processor(self, dialect, coltype): + def process(value): + # pyodbc returns a datetime.time object; no need to convert + return value + + return process + + +class MySQLExecutionContext_pyodbc(MySQLExecutionContext): + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT LAST_INSERT_ID()") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + + +class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): + supports_statement_cache = True + colspecs = util.update_copy(MySQLDialect.colspecs, {Time: _pyodbcTIME}) + supports_unicode_statements = True + execution_ctx_cls = MySQLExecutionContext_pyodbc + + pyodbc_driver_name = "MySQL" + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + + # set this to None as _fetch_setting attempts to use it (None is OK) + self._connection_charset = None + try: + value = self._fetch_setting(connection, "character_set_client") + if value: + return value + except exc.DBAPIError: + pass + + util.warn( + "Could not detect the connection character set. " + "Assuming latin1." + ) + return "latin1" + + def _get_server_version_info(self, connection): + return MySQLDialect._get_server_version_info(self, connection) + + def _extract_error_code(self, exception): + m = re.compile(r"\((\d+)\)").search(str(exception.args)) + c = m.group(1) + if c: + return int(c) + else: + return None + + def on_connect(self): + super_ = super(MySQLDialect_pyodbc, self).on_connect() + + def on_connect(conn): + if super_ is not None: + super_(conn) + + # declare Unicode encoding for pyodbc as per + # https://github.com/mkleehammer/pyodbc/wiki/Unicode + pyodbc_SQL_CHAR = 1 # pyodbc.SQL_CHAR + pyodbc_SQL_WCHAR = -8 # pyodbc.SQL_WCHAR + conn.setdecoding(pyodbc_SQL_CHAR, encoding="utf-8") + conn.setdecoding(pyodbc_SQL_WCHAR, encoding="utf-8") + conn.setencoding(encoding="utf-8") + + return on_connect + + +dialect = MySQLDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py new file mode 100644 index 0000000..27394bb --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/reflection.py @@ -0,0 +1,558 @@ +# mysql/reflection.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import re + +from .enumerated import ENUM +from .enumerated import SET +from .types import DATETIME +from .types import TIME +from .types import TIMESTAMP +from ... import log +from ... import types as sqltypes +from ... import util + + +class ReflectedState(object): + """Stores raw information about a SHOW CREATE TABLE statement.""" + + def __init__(self): + self.columns = [] + self.table_options = {} + self.table_name = None + self.keys = [] + self.fk_constraints = [] + self.ck_constraints = [] + + +@log.class_logger +class MySQLTableDefinitionParser(object): + """Parses the results of a SHOW CREATE TABLE statement.""" + + def __init__(self, dialect, preparer): + self.dialect = dialect + self.preparer = preparer + self._prep_regexes() + + def parse(self, show_create, charset): + state = ReflectedState() + state.charset = charset + for line in re.split(r"\r?\n", show_create): + if line.startswith(" " + self.preparer.initial_quote): + self._parse_column(line, state) + # a regular table options line + elif line.startswith(") "): + self._parse_table_options(line, state) + # an ANSI-mode table options line + elif line == ")": + pass + elif line.startswith("CREATE "): + self._parse_table_name(line, state) + # Not present in real reflection, but may be if + # loading from a file. + elif not line: + pass + else: + type_, spec = self._parse_constraints(line) + if type_ is None: + util.warn("Unknown schema content: %r" % line) + elif type_ == "key": + state.keys.append(spec) + elif type_ == "fk_constraint": + state.fk_constraints.append(spec) + elif type_ == "ck_constraint": + state.ck_constraints.append(spec) + else: + pass + return state + + def _parse_constraints(self, line): + """Parse a KEY or CONSTRAINT line. + + :param line: A line of SHOW CREATE TABLE output + """ + + # KEY + m = self._re_key.match(line) + if m: + spec = m.groupdict() + # convert columns into name, length pairs + # NOTE: we may want to consider SHOW INDEX as the + # format of indexes in MySQL becomes more complex + spec["columns"] = self._parse_keyexprs(spec["columns"]) + if spec["version_sql"]: + m2 = self._re_key_version_sql.match(spec["version_sql"]) + if m2 and m2.groupdict()["parser"]: + spec["parser"] = m2.groupdict()["parser"] + if spec["parser"]: + spec["parser"] = self.preparer.unformat_identifiers( + spec["parser"] + )[0] + return "key", spec + + # FOREIGN KEY CONSTRAINT + m = self._re_fk_constraint.match(line) + if m: + spec = m.groupdict() + spec["table"] = self.preparer.unformat_identifiers(spec["table"]) + spec["local"] = [c[0] for c in self._parse_keyexprs(spec["local"])] + spec["foreign"] = [ + c[0] for c in self._parse_keyexprs(spec["foreign"]) + ] + return "fk_constraint", spec + + # CHECK constraint + m = self._re_ck_constraint.match(line) + if m: + spec = m.groupdict() + return "ck_constraint", spec + + # PARTITION and SUBPARTITION + m = self._re_partition.match(line) + if m: + # Punt! + return "partition", line + + # No match. + return (None, line) + + def _parse_table_name(self, line, state): + """Extract the table name. + + :param line: The first line of SHOW CREATE TABLE + """ + + regex, cleanup = self._pr_name + m = regex.match(line) + if m: + state.table_name = cleanup(m.group("name")) + + def _parse_table_options(self, line, state): + """Build a dictionary of all reflected table-level options. + + :param line: The final line of SHOW CREATE TABLE output. + """ + + options = {} + + if not line or line == ")": + pass + + else: + rest_of_line = line[:] + for regex, cleanup in self._pr_options: + m = regex.search(rest_of_line) + if not m: + continue + directive, value = m.group("directive"), m.group("val") + if cleanup: + value = cleanup(value) + options[directive.lower()] = value + rest_of_line = regex.sub("", rest_of_line) + + for nope in ("auto_increment", "data directory", "index directory"): + options.pop(nope, None) + + for opt, val in options.items(): + state.table_options["%s_%s" % (self.dialect.name, opt)] = val + + def _parse_column(self, line, state): + """Extract column details. + + Falls back to a 'minimal support' variant if full parse fails. + + :param line: Any column-bearing line from SHOW CREATE TABLE + """ + + spec = None + m = self._re_column.match(line) + if m: + spec = m.groupdict() + spec["full"] = True + else: + m = self._re_column_loose.match(line) + if m: + spec = m.groupdict() + spec["full"] = False + if not spec: + util.warn("Unknown column definition %r" % line) + return + if not spec["full"]: + util.warn("Incomplete reflection of column definition %r" % line) + + name, type_, args = spec["name"], spec["coltype"], spec["arg"] + + try: + col_type = self.dialect.ischema_names[type_] + except KeyError: + util.warn( + "Did not recognize type '%s' of column '%s'" % (type_, name) + ) + col_type = sqltypes.NullType + + # Column type positional arguments eg. varchar(32) + if args is None or args == "": + type_args = [] + elif args[0] == "'" and args[-1] == "'": + type_args = self._re_csv_str.findall(args) + else: + type_args = [int(v) for v in self._re_csv_int.findall(args)] + + # Column type keyword options + type_kw = {} + + if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)): + if type_args: + type_kw["fsp"] = type_args.pop(0) + + for kw in ("unsigned", "zerofill"): + if spec.get(kw, False): + type_kw[kw] = True + for kw in ("charset", "collate"): + if spec.get(kw, False): + type_kw[kw] = spec[kw] + if issubclass(col_type, (ENUM, SET)): + type_args = _strip_values(type_args) + + if issubclass(col_type, SET) and "" in type_args: + type_kw["retrieve_as_bitwise"] = True + + type_instance = col_type(*type_args, **type_kw) + + col_kw = {} + + # NOT NULL + col_kw["nullable"] = True + # this can be "NULL" in the case of TIMESTAMP + if spec.get("notnull", False) == "NOT NULL": + col_kw["nullable"] = False + + # AUTO_INCREMENT + if spec.get("autoincr", False): + col_kw["autoincrement"] = True + elif issubclass(col_type, sqltypes.Integer): + col_kw["autoincrement"] = False + + # DEFAULT + default = spec.get("default", None) + + if default == "NULL": + # eliminates the need to deal with this later. + default = None + + comment = spec.get("comment", None) + + if comment is not None: + comment = comment.replace("\\\\", "\\").replace("''", "'") + + sqltext = spec.get("generated") + if sqltext is not None: + computed = dict(sqltext=sqltext) + persisted = spec.get("persistence") + if persisted is not None: + computed["persisted"] = persisted == "STORED" + col_kw["computed"] = computed + + col_d = dict( + name=name, type=type_instance, default=default, comment=comment + ) + col_d.update(col_kw) + state.columns.append(col_d) + + def _describe_to_create(self, table_name, columns): + """Re-format DESCRIBE output as a SHOW CREATE TABLE string. + + DESCRIBE is a much simpler reflection and is sufficient for + reflecting views for runtime use. This method formats DDL + for columns only- keys are omitted. + + :param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples. + SHOW FULL COLUMNS FROM rows must be rearranged for use with + this function. + """ + + buffer = [] + for row in columns: + (name, col_type, nullable, default, extra) = [ + row[i] for i in (0, 1, 2, 4, 5) + ] + + line = [" "] + line.append(self.preparer.quote_identifier(name)) + line.append(col_type) + if not nullable: + line.append("NOT NULL") + if default: + if "auto_increment" in default: + pass + elif col_type.startswith("timestamp") and default.startswith( + "C" + ): + line.append("DEFAULT") + line.append(default) + elif default == "NULL": + line.append("DEFAULT") + line.append(default) + else: + line.append("DEFAULT") + line.append("'%s'" % default.replace("'", "''")) + if extra: + line.append(extra) + + buffer.append(" ".join(line)) + + return "".join( + [ + ( + "CREATE TABLE %s (\n" + % self.preparer.quote_identifier(table_name) + ), + ",\n".join(buffer), + "\n) ", + ] + ) + + def _parse_keyexprs(self, identifiers): + """Unpack '"col"(2),"col" ASC'-ish strings into components.""" + + return self._re_keyexprs.findall(identifiers) + + def _prep_regexes(self): + """Pre-compile regular expressions.""" + + self._re_columns = [] + self._pr_options = [] + + _final = self.preparer.final_quote + + quotes = dict( + zip( + ("iq", "fq", "esc_fq"), + [ + re.escape(s) + for s in ( + self.preparer.initial_quote, + _final, + self.preparer._escape_identifier(_final), + ) + ], + ) + ) + + self._pr_name = _pr_compile( + r"^CREATE (?:\w+ +)?TABLE +" + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($" % quotes, + self.preparer._unescape_identifier, + ) + + # `col`,`col2`(32),`col3`(15) DESC + # + self._re_keyexprs = _re_compile( + r"(?:" + r"(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)" + r"(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+" % quotes + ) + + # 'foo' or 'foo','bar' or 'fo,o','ba''a''r' + self._re_csv_str = _re_compile(r"\x27(?:\x27\x27|[^\x27])*\x27") + + # 123 or 123,456 + self._re_csv_int = _re_compile(r"\d+") + + # `colname` <type> [type opts] + # (NOT NULL | NULL) + # DEFAULT ('value' | CURRENT_TIMESTAMP...) + # COMMENT 'comment' + # COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT) + # STORAGE (DISK|MEMORY) + self._re_column = _re_compile( + r" " + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"(?P<coltype>\w+)" + r"(?:\((?P<arg>(?:\d+|\d+,\d+|" + r"(?:'(?:''|[^'])*',?)+))\))?" + r"(?: +(?P<unsigned>UNSIGNED))?" + r"(?: +(?P<zerofill>ZEROFILL))?" + r"(?: +CHARACTER SET +(?P<charset>[\w_]+))?" + r"(?: +COLLATE +(?P<collate>[\w_]+))?" + r"(?: +(?P<notnull>(?:NOT )?NULL))?" + r"(?: +DEFAULT +(?P<default>" + r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+" + r"(?: +ON UPDATE [\-\w\.\(\)]+)?)" + r"))?" + r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\(" + r".*\))? ?(?P<persistence>VIRTUAL|STORED)?)?" + r"(?: +(?P<autoincr>AUTO_INCREMENT))?" + r"(?: +COMMENT +'(?P<comment>(?:''|[^'])*)')?" + r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?" + r"(?: +STORAGE +(?P<storage>\w+))?" + r"(?: +(?P<extra>.*))?" + r",?$" % quotes + ) + + # Fallback, try to parse as little as possible + self._re_column_loose = _re_compile( + r" " + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"(?P<coltype>\w+)" + r"(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?" + r".*?(?P<notnull>(?:NOT )NULL)?" % quotes + ) + + # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))? + # (`col` (ASC|DESC)?, `col` (ASC|DESC)?) + # KEY_BLOCK_SIZE size | WITH PARSER name /*!50100 WITH PARSER name */ + self._re_key = _re_compile( + r" " + r"(?:(?P<type>\S+) )?KEY" + r"(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?" + r"(?: +USING +(?P<using_pre>\S+))?" + r" +\((?P<columns>.+?)\)" + r"(?: +USING +(?P<using_post>\S+))?" + r"(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?" + r"(?: +WITH PARSER +(?P<parser>\S+))?" + r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?" + r"(?: +/\*(?P<version_sql>.+)\*/ *)?" + r",?$" % quotes + ) + + # https://forums.mysql.com/read.php?20,567102,567111#msg-567111 + # It means if the MySQL version >= \d+, execute what's in the comment + self._re_key_version_sql = _re_compile( + r"\!\d+ " r"(?: *WITH PARSER +(?P<parser>\S+) *)?" + ) + + # CONSTRAINT `name` FOREIGN KEY (`local_col`) + # REFERENCES `remote` (`remote_col`) + # MATCH FULL | MATCH PARTIAL | MATCH SIMPLE + # ON DELETE CASCADE ON UPDATE RESTRICT + # + # unique constraints come back as KEYs + kw = quotes.copy() + kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION" + self._re_fk_constraint = _re_compile( + r" " + r"CONSTRAINT +" + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"FOREIGN KEY +" + r"\((?P<local>[^\)]+?)\) REFERENCES +" + r"(?P<table>%(iq)s[^%(fq)s]+%(fq)s" + r"(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +" + r"\((?P<foreign>[^\)]+?)\)" + r"(?: +(?P<match>MATCH \w+))?" + r"(?: +ON DELETE (?P<ondelete>%(on)s))?" + r"(?: +ON UPDATE (?P<onupdate>%(on)s))?" % kw + ) + + # CONSTRAINT `CONSTRAINT_1` CHECK (`x` > 5)' + # testing on MariaDB 10.2 shows that the CHECK constraint + # is returned on a line by itself, so to match without worrying + # about parenthesis in the expression we go to the end of the line + self._re_ck_constraint = _re_compile( + r" " + r"CONSTRAINT +" + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"CHECK +" + r"\((?P<sqltext>.+)\),?" % kw + ) + + # PARTITION + # + # punt! + self._re_partition = _re_compile(r"(?:.*)(?:SUB)?PARTITION(?:.*)") + + # Table-level options (COLLATE, ENGINE, etc.) + # Do the string options first, since they have quoted + # strings we need to get rid of. + for option in _options_of_type_string: + self._add_option_string(option) + + for option in ( + "ENGINE", + "TYPE", + "AUTO_INCREMENT", + "AVG_ROW_LENGTH", + "CHARACTER SET", + "DEFAULT CHARSET", + "CHECKSUM", + "COLLATE", + "DELAY_KEY_WRITE", + "INSERT_METHOD", + "MAX_ROWS", + "MIN_ROWS", + "PACK_KEYS", + "ROW_FORMAT", + "KEY_BLOCK_SIZE", + ): + self._add_option_word(option) + + self._add_option_regex("UNION", r"\([^\)]+\)") + self._add_option_regex("TABLESPACE", r".*? STORAGE DISK") + self._add_option_regex( + "RAID_TYPE", + r"\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+", + ) + + _optional_equals = r"(?:\s*(?:=\s*)|\s+)" + + def _add_option_string(self, directive): + regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % ( + re.escape(directive), + self._optional_equals, + ) + self._pr_options.append( + _pr_compile( + regex, lambda v: v.replace("\\\\", "\\").replace("''", "'") + ) + ) + + def _add_option_word(self, directive): + regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % ( + re.escape(directive), + self._optional_equals, + ) + self._pr_options.append(_pr_compile(regex)) + + def _add_option_regex(self, directive, regex): + regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % ( + re.escape(directive), + self._optional_equals, + regex, + ) + self._pr_options.append(_pr_compile(regex)) + + +_options_of_type_string = ( + "COMMENT", + "DATA DIRECTORY", + "INDEX DIRECTORY", + "PASSWORD", + "CONNECTION", +) + + +def _pr_compile(regex, cleanup=None): + """Prepare a 2-tuple of compiled regex and callable.""" + + return (_re_compile(regex), cleanup) + + +def _re_compile(regex): + """Compile a string to regex, I and UNICODE.""" + + return re.compile(regex, re.I | re.UNICODE) + + +def _strip_values(values): + "Strip reflected values quotes" + strip_values = [] + for a in values: + if a[0:1] == '"' or a[0:1] == "'": + # strip enclosing quotes and unquote interior + a = a[1:-1].replace(a[0] * 2, a[0]) + strip_values.append(a) + return strip_values diff --git a/lib/sqlalchemy/dialects/mysql/reserved_words.py b/lib/sqlalchemy/dialects/mysql/reserved_words.py new file mode 100644 index 0000000..995168b --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/reserved_words.py @@ -0,0 +1,564 @@ +# mysql/reserved_words.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +# generated using: +# https://gist.github.com/kkirsche/4f31f2153ed7a3248be1ec44ca6ddbc9 +# +# https://mariadb.com/kb/en/reserved-words/ +# includes: Reserved Words, Oracle Mode (separate set unioned) +# excludes: Exceptions, Function Names +RESERVED_WORDS_MARIADB = { + "accessible", + "add", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "asensitive", + "before", + "between", + "bigint", + "binary", + "blob", + "both", + "by", + "call", + "cascade", + "case", + "change", + "char", + "character", + "check", + "collate", + "column", + "condition", + "constraint", + "continue", + "convert", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "databases", + "day_hour", + "day_microsecond", + "day_minute", + "day_second", + "dec", + "decimal", + "declare", + "default", + "delayed", + "delete", + "desc", + "describe", + "deterministic", + "distinct", + "distinctrow", + "div", + "do_domain_ids", + "double", + "drop", + "dual", + "each", + "else", + "elseif", + "enclosed", + "escaped", + "except", + "exists", + "exit", + "explain", + "false", + "fetch", + "float", + "float4", + "float8", + "for", + "force", + "foreign", + "from", + "fulltext", + "general", + "grant", + "group", + "having", + "high_priority", + "hour_microsecond", + "hour_minute", + "hour_second", + "if", + "ignore", + "ignore_domain_ids", + "ignore_server_ids", + "in", + "index", + "infile", + "inner", + "inout", + "insensitive", + "insert", + "int", + "int1", + "int2", + "int3", + "int4", + "int8", + "integer", + "intersect", + "interval", + "into", + "is", + "iterate", + "join", + "key", + "keys", + "kill", + "leading", + "leave", + "left", + "like", + "limit", + "linear", + "lines", + "load", + "localtime", + "localtimestamp", + "lock", + "long", + "longblob", + "longtext", + "loop", + "low_priority", + "master_heartbeat_period", + "master_ssl_verify_server_cert", + "match", + "maxvalue", + "mediumblob", + "mediumint", + "mediumtext", + "middleint", + "minute_microsecond", + "minute_second", + "mod", + "modifies", + "natural", + "no_write_to_binlog", + "not", + "null", + "numeric", + "offset", + "on", + "optimize", + "option", + "optionally", + "or", + "order", + "out", + "outer", + "outfile", + "over", + "page_checksum", + "parse_vcol_expr", + "partition", + "position", + "precision", + "primary", + "procedure", + "purge", + "range", + "read", + "read_write", + "reads", + "real", + "recursive", + "ref_system_id", + "references", + "regexp", + "release", + "rename", + "repeat", + "replace", + "require", + "resignal", + "restrict", + "return", + "returning", + "revoke", + "right", + "rlike", + "rows", + "schema", + "schemas", + "second_microsecond", + "select", + "sensitive", + "separator", + "set", + "show", + "signal", + "slow", + "smallint", + "spatial", + "specific", + "sql", + "sql_big_result", + "sql_calc_found_rows", + "sql_small_result", + "sqlexception", + "sqlstate", + "sqlwarning", + "ssl", + "starting", + "stats_auto_recalc", + "stats_persistent", + "stats_sample_pages", + "straight_join", + "table", + "terminated", + "then", + "tinyblob", + "tinyint", + "tinytext", + "to", + "trailing", + "trigger", + "true", + "undo", + "union", + "unique", + "unlock", + "unsigned", + "update", + "usage", + "use", + "using", + "utc_date", + "utc_time", + "utc_timestamp", + "values", + "varbinary", + "varchar", + "varcharacter", + "varying", + "when", + "where", + "while", + "window", + "with", + "write", + "xor", + "year_month", + "zerofill", +}.union( + { + "body", + "elsif", + "goto", + "history", + "others", + "package", + "period", + "raise", + "rowtype", + "system", + "system_time", + "versioning", + "without", + } +) + +# https://dev.mysql.com/doc/refman/8.0/en/keywords.html +# https://dev.mysql.com/doc/refman/5.7/en/keywords.html +# https://dev.mysql.com/doc/refman/5.6/en/keywords.html +# includes: MySQL x.0 Keywords and Reserved Words +# excludes: MySQL x.0 New Keywords and Reserved Words, +# MySQL x.0 Removed Keywords and Reserved Words +RESERVED_WORDS_MYSQL = { + "accessible", + "add", + "admin", + "all", + "alter", + "analyze", + "and", + "array", + "as", + "asc", + "asensitive", + "before", + "between", + "bigint", + "binary", + "blob", + "both", + "by", + "call", + "cascade", + "case", + "change", + "char", + "character", + "check", + "collate", + "column", + "condition", + "constraint", + "continue", + "convert", + "create", + "cross", + "cube", + "cume_dist", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "databases", + "day_hour", + "day_microsecond", + "day_minute", + "day_second", + "dec", + "decimal", + "declare", + "default", + "delayed", + "delete", + "dense_rank", + "desc", + "describe", + "deterministic", + "distinct", + "distinctrow", + "div", + "double", + "drop", + "dual", + "each", + "else", + "elseif", + "empty", + "enclosed", + "escaped", + "except", + "exists", + "exit", + "explain", + "false", + "fetch", + "first_value", + "float", + "float4", + "float8", + "for", + "force", + "foreign", + "from", + "fulltext", + "function", + "general", + "generated", + "get", + "get_master_public_key", + "grant", + "group", + "grouping", + "groups", + "having", + "high_priority", + "hour_microsecond", + "hour_minute", + "hour_second", + "if", + "ignore", + "ignore_server_ids", + "in", + "index", + "infile", + "inner", + "inout", + "insensitive", + "insert", + "int", + "int1", + "int2", + "int3", + "int4", + "int8", + "integer", + "interval", + "into", + "io_after_gtids", + "io_before_gtids", + "is", + "iterate", + "join", + "json_table", + "key", + "keys", + "kill", + "lag", + "last_value", + "lateral", + "lead", + "leading", + "leave", + "left", + "like", + "limit", + "linear", + "lines", + "load", + "localtime", + "localtimestamp", + "lock", + "long", + "longblob", + "longtext", + "loop", + "low_priority", + "master_bind", + "master_heartbeat_period", + "master_ssl_verify_server_cert", + "match", + "maxvalue", + "mediumblob", + "mediumint", + "mediumtext", + "member", + "middleint", + "minute_microsecond", + "minute_second", + "mod", + "modifies", + "natural", + "no_write_to_binlog", + "not", + "nth_value", + "ntile", + "null", + "numeric", + "of", + "on", + "optimize", + "optimizer_costs", + "option", + "optionally", + "or", + "order", + "out", + "outer", + "outfile", + "over", + "parse_gcol_expr", + "partition", + "percent_rank", + "persist", + "persist_only", + "precision", + "primary", + "procedure", + "purge", + "range", + "rank", + "read", + "read_write", + "reads", + "real", + "recursive", + "references", + "regexp", + "release", + "rename", + "repeat", + "replace", + "require", + "resignal", + "restrict", + "return", + "revoke", + "right", + "rlike", + "role", + "row", + "row_number", + "rows", + "schema", + "schemas", + "second_microsecond", + "select", + "sensitive", + "separator", + "set", + "show", + "signal", + "slow", + "smallint", + "spatial", + "specific", + "sql", + "sql_after_gtids", + "sql_before_gtids", + "sql_big_result", + "sql_calc_found_rows", + "sql_small_result", + "sqlexception", + "sqlstate", + "sqlwarning", + "ssl", + "starting", + "stored", + "straight_join", + "system", + "table", + "terminated", + "then", + "tinyblob", + "tinyint", + "tinytext", + "to", + "trailing", + "trigger", + "true", + "undo", + "union", + "unique", + "unlock", + "unsigned", + "update", + "usage", + "use", + "using", + "utc_date", + "utc_time", + "utc_timestamp", + "values", + "varbinary", + "varchar", + "varcharacter", + "varying", + "virtual", + "when", + "where", + "while", + "window", + "with", + "write", + "xor", + "year_month", + "zerofill", +} diff --git a/lib/sqlalchemy/dialects/mysql/types.py b/lib/sqlalchemy/dialects/mysql/types.py new file mode 100644 index 0000000..b81ee95 --- /dev/null +++ b/lib/sqlalchemy/dialects/mysql/types.py @@ -0,0 +1,773 @@ +# mysql/types.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import datetime + +from ... import exc +from ... import types as sqltypes +from ... import util + + +class _NumericType(object): + """Base for MySQL numeric types. + + This is the base both for NUMERIC as well as INTEGER, hence + it's a mixin. + + """ + + def __init__(self, unsigned=False, zerofill=False, **kw): + self.unsigned = unsigned + self.zerofill = zerofill + super(_NumericType, self).__init__(**kw) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_NumericType, sqltypes.Numeric] + ) + + +class _FloatType(_NumericType, sqltypes.Float): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + if isinstance(self, (REAL, DOUBLE)) and ( + (precision is None and scale is not None) + or (precision is not None and scale is None) + ): + raise exc.ArgumentError( + "You must specify both precision and scale or omit " + "both altogether." + ) + super(_FloatType, self).__init__( + precision=precision, asdecimal=asdecimal, **kw + ) + self.scale = scale + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_FloatType, _NumericType, sqltypes.Float] + ) + + +class _IntegerType(_NumericType, sqltypes.Integer): + def __init__(self, display_width=None, **kw): + self.display_width = display_width + super(_IntegerType, self).__init__(**kw) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer] + ) + + +class _StringType(sqltypes.String): + """Base for MySQL string types.""" + + def __init__( + self, + charset=None, + collation=None, + ascii=False, # noqa + binary=False, + unicode=False, + national=False, + **kw + ): + self.charset = charset + + # allow collate= or collation= + kw.setdefault("collation", kw.pop("collate", collation)) + + self.ascii = ascii + self.unicode = unicode + self.binary = binary + self.national = national + super(_StringType, self).__init__(**kw) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_StringType, sqltypes.String] + ) + + +class _MatchType(sqltypes.Float, sqltypes.MatchType): + def __init__(self, **kw): + # TODO: float arguments? + sqltypes.Float.__init__(self) + sqltypes.MatchType.__init__(self) + + +class NUMERIC(_NumericType, sqltypes.NUMERIC): + """MySQL NUMERIC type.""" + + __visit_name__ = "NUMERIC" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a NUMERIC. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(NUMERIC, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class DECIMAL(_NumericType, sqltypes.DECIMAL): + """MySQL DECIMAL type.""" + + __visit_name__ = "DECIMAL" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DECIMAL. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(DECIMAL, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class DOUBLE(_FloatType): + """MySQL DOUBLE type.""" + + __visit_name__ = "DOUBLE" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DOUBLE. + + .. note:: + + The :class:`.DOUBLE` type by default converts from float + to Decimal, using a truncation that defaults to 10 digits. + Specify either ``scale=n`` or ``decimal_return_scale=n`` in order + to change this scale, or ``asdecimal=False`` to return values + directly as Python floating points. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(DOUBLE, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class REAL(_FloatType, sqltypes.REAL): + """MySQL REAL type.""" + + __visit_name__ = "REAL" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a REAL. + + .. note:: + + The :class:`.REAL` type by default converts from float + to Decimal, using a truncation that defaults to 10 digits. + Specify either ``scale=n`` or ``decimal_return_scale=n`` in order + to change this scale, or ``asdecimal=False`` to return values + directly as Python floating points. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(REAL, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class FLOAT(_FloatType, sqltypes.FLOAT): + """MySQL FLOAT type.""" + + __visit_name__ = "FLOAT" + + def __init__(self, precision=None, scale=None, asdecimal=False, **kw): + """Construct a FLOAT. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(FLOAT, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + def bind_processor(self, dialect): + return None + + +class INTEGER(_IntegerType, sqltypes.INTEGER): + """MySQL INTEGER type.""" + + __visit_name__ = "INTEGER" + + def __init__(self, display_width=None, **kw): + """Construct an INTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(INTEGER, self).__init__(display_width=display_width, **kw) + + +class BIGINT(_IntegerType, sqltypes.BIGINT): + """MySQL BIGINTEGER type.""" + + __visit_name__ = "BIGINT" + + def __init__(self, display_width=None, **kw): + """Construct a BIGINTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(BIGINT, self).__init__(display_width=display_width, **kw) + + +class MEDIUMINT(_IntegerType): + """MySQL MEDIUMINTEGER type.""" + + __visit_name__ = "MEDIUMINT" + + def __init__(self, display_width=None, **kw): + """Construct a MEDIUMINTEGER + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(MEDIUMINT, self).__init__(display_width=display_width, **kw) + + +class TINYINT(_IntegerType): + """MySQL TINYINT type.""" + + __visit_name__ = "TINYINT" + + def __init__(self, display_width=None, **kw): + """Construct a TINYINT. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(TINYINT, self).__init__(display_width=display_width, **kw) + + +class SMALLINT(_IntegerType, sqltypes.SMALLINT): + """MySQL SMALLINTEGER type.""" + + __visit_name__ = "SMALLINT" + + def __init__(self, display_width=None, **kw): + """Construct a SMALLINTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super(SMALLINT, self).__init__(display_width=display_width, **kw) + + +class BIT(sqltypes.TypeEngine): + """MySQL BIT type. + + This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater + for MyISAM, MEMORY, InnoDB and BDB. For older versions, use a + MSTinyInteger() type. + + """ + + __visit_name__ = "BIT" + + def __init__(self, length=None): + """Construct a BIT. + + :param length: Optional, number of bits. + + """ + self.length = length + + def result_processor(self, dialect, coltype): + """Convert a MySQL's 64 bit, variable length binary string to a long. + + TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector + already do this, so this logic should be moved to those dialects. + + """ + + def process(value): + if value is not None: + v = 0 + for i in value: + if not isinstance(i, int): + i = ord(i) # convert byte to int on Python 2 + v = v << 8 | i + return v + return value + + return process + + +class TIME(sqltypes.TIME): + """MySQL TIME type.""" + + __visit_name__ = "TIME" + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL TIME type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the TIME type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. + + """ + super(TIME, self).__init__(timezone=timezone) + self.fsp = fsp + + def result_processor(self, dialect, coltype): + time = datetime.time + + def process(value): + # convert from a timedelta value + if value is not None: + microseconds = value.microseconds + seconds = value.seconds + minutes = seconds // 60 + return time( + minutes // 60, + minutes % 60, + seconds - minutes * 60, + microsecond=microseconds, + ) + else: + return None + + return process + + +class TIMESTAMP(sqltypes.TIMESTAMP): + """MySQL TIMESTAMP type.""" + + __visit_name__ = "TIMESTAMP" + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL TIMESTAMP type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6.4 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the TIMESTAMP type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. + + """ + super(TIMESTAMP, self).__init__(timezone=timezone) + self.fsp = fsp + + +class DATETIME(sqltypes.DATETIME): + """MySQL DATETIME type.""" + + __visit_name__ = "DATETIME" + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL DATETIME type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6.4 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the DATETIME type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. + + """ + super(DATETIME, self).__init__(timezone=timezone) + self.fsp = fsp + + +class YEAR(sqltypes.TypeEngine): + """MySQL YEAR type, for single byte storage of years 1901-2155.""" + + __visit_name__ = "YEAR" + + def __init__(self, display_width=None): + self.display_width = display_width + + +class TEXT(_StringType, sqltypes.TEXT): + """MySQL TEXT type, for text up to 2^16 characters.""" + + __visit_name__ = "TEXT" + + def __init__(self, length=None, **kw): + """Construct a TEXT. + + :param length: Optional, if provided the server may optimize storage + by substituting the smallest TEXT type sufficient to store + ``length`` characters. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(TEXT, self).__init__(length=length, **kw) + + +class TINYTEXT(_StringType): + """MySQL TINYTEXT type, for text up to 2^8 characters.""" + + __visit_name__ = "TINYTEXT" + + def __init__(self, **kwargs): + """Construct a TINYTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(TINYTEXT, self).__init__(**kwargs) + + +class MEDIUMTEXT(_StringType): + """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" + + __visit_name__ = "MEDIUMTEXT" + + def __init__(self, **kwargs): + """Construct a MEDIUMTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(MEDIUMTEXT, self).__init__(**kwargs) + + +class LONGTEXT(_StringType): + """MySQL LONGTEXT type, for text up to 2^32 characters.""" + + __visit_name__ = "LONGTEXT" + + def __init__(self, **kwargs): + """Construct a LONGTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(LONGTEXT, self).__init__(**kwargs) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """MySQL VARCHAR type, for variable-length character data.""" + + __visit_name__ = "VARCHAR" + + def __init__(self, length=None, **kwargs): + """Construct a VARCHAR. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super(VARCHAR, self).__init__(length=length, **kwargs) + + +class CHAR(_StringType, sqltypes.CHAR): + """MySQL CHAR type, for fixed-length character data.""" + + __visit_name__ = "CHAR" + + def __init__(self, length=None, **kwargs): + """Construct a CHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + super(CHAR, self).__init__(length=length, **kwargs) + + @classmethod + def _adapt_string_for_cast(self, type_): + # copy the given string type into a CHAR + # for the purposes of rendering a CAST expression + type_ = sqltypes.to_instance(type_) + if isinstance(type_, sqltypes.CHAR): + return type_ + elif isinstance(type_, _StringType): + return CHAR( + length=type_.length, + charset=type_.charset, + collation=type_.collation, + ascii=type_.ascii, + binary=type_.binary, + unicode=type_.unicode, + national=False, # not supported in CAST + ) + else: + return CHAR(length=type_.length) + + +class NVARCHAR(_StringType, sqltypes.NVARCHAR): + """MySQL NVARCHAR type. + + For variable-length character data in the server's configured national + character set. + """ + + __visit_name__ = "NVARCHAR" + + def __init__(self, length=None, **kwargs): + """Construct an NVARCHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs["national"] = True + super(NVARCHAR, self).__init__(length=length, **kwargs) + + +class NCHAR(_StringType, sqltypes.NCHAR): + """MySQL NCHAR type. + + For fixed-length character data in the server's configured national + character set. + """ + + __visit_name__ = "NCHAR" + + def __init__(self, length=None, **kwargs): + """Construct an NCHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs["national"] = True + super(NCHAR, self).__init__(length=length, **kwargs) + + +class TINYBLOB(sqltypes._Binary): + """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" + + __visit_name__ = "TINYBLOB" + + +class MEDIUMBLOB(sqltypes._Binary): + """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" + + __visit_name__ = "MEDIUMBLOB" + + +class LONGBLOB(sqltypes._Binary): + """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" + + __visit_name__ = "LONGBLOB" diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py new file mode 100644 index 0000000..c83e057 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -0,0 +1,58 @@ +# oracle/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from . import base # noqa +from . import cx_oracle # noqa +from .base import BFILE +from .base import BINARY_DOUBLE +from .base import BINARY_FLOAT +from .base import BLOB +from .base import CHAR +from .base import CLOB +from .base import DATE +from .base import DOUBLE_PRECISION +from .base import FLOAT +from .base import INTERVAL +from .base import LONG +from .base import NCHAR +from .base import NCLOB +from .base import NUMBER +from .base import NVARCHAR +from .base import NVARCHAR2 +from .base import RAW +from .base import ROWID +from .base import TIMESTAMP +from .base import VARCHAR +from .base import VARCHAR2 + + +base.dialect = dialect = cx_oracle.dialect + +__all__ = ( + "VARCHAR", + "NVARCHAR", + "CHAR", + "NCHAR", + "DATE", + "NUMBER", + "BLOB", + "BFILE", + "CLOB", + "NCLOB", + "TIMESTAMP", + "RAW", + "FLOAT", + "DOUBLE_PRECISION", + "BINARY_DOUBLE", + "BINARY_FLOAT", + "LONG", + "dialect", + "INTERVAL", + "VARCHAR2", + "NVARCHAR2", + "ROWID", +) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py new file mode 100644 index 0000000..77f0dbd --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -0,0 +1,2522 @@ +# oracle/base.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" +.. dialect:: oracle + :name: Oracle + :full_support: 11.2, 18c + :normal_support: 11+ + :best_effort: 8+ + + +Auto Increment Behavior +----------------------- + +SQLAlchemy Table objects which include integer primary keys are usually +assumed to have "autoincrementing" behavior, meaning they can generate their +own primary key values upon INSERT. For use within Oracle, two options are +available, which are the use of IDENTITY columns (Oracle 12 and above only) +or the association of a SEQUENCE with the column. + +Specifying GENERATED AS IDENTITY (Oracle 12 and above) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Starting from version 12 Oracle can make use of identity columns using +the :class:`_sql.Identity` to specify the autoincrementing behavior:: + + t = Table('mytable', metadata, + Column('id', Integer, Identity(start=3), primary_key=True), + Column(...), ... + ) + +The CREATE TABLE for the above :class:`_schema.Table` object would be: + +.. sourcecode:: sql + + CREATE TABLE mytable ( + id INTEGER GENERATED BY DEFAULT AS IDENTITY (START WITH 3), + ..., + PRIMARY KEY (id) + ) + +The :class:`_schema.Identity` object support many options to control the +"autoincrementing" behavior of the column, like the starting value, the +incrementing value, etc. +In addition to the standard options, Oracle supports setting +:paramref:`_schema.Identity.always` to ``None`` to use the default +generated mode, rendering GENERATED AS IDENTITY in the DDL. It also supports +setting :paramref:`_schema.Identity.on_null` to ``True`` to specify ON NULL +in conjunction with a 'BY DEFAULT' identity column. + +Using a SEQUENCE (all Oracle versions) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Older version of Oracle had no "autoincrement" +feature, SQLAlchemy relies upon sequences to produce these values. With the +older Oracle versions, *a sequence must always be explicitly specified to +enable autoincrement*. This is divergent with the majority of documentation +examples which assume the usage of an autoincrement-capable database. To +specify sequences, use the sqlalchemy.schema.Sequence object which is passed +to a Column construct:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq'), primary_key=True), + Column(...), ... + ) + +This step is also required when using table reflection, i.e. autoload_with=engine:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq'), primary_key=True), + autoload_with=engine + ) + +.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct + in a :class:`_schema.Column` to specify the option of an autoincrementing + column. + +.. _oracle_isolation_level: + +Transaction Isolation Level / Autocommit +---------------------------------------- + +The Oracle database supports "READ COMMITTED" and "SERIALIZABLE" modes of +isolation. The AUTOCOMMIT isolation level is also supported by the cx_Oracle +dialect. + +To set using per-connection execution options:: + + connection = engine.connect() + connection = connection.execution_options( + isolation_level="AUTOCOMMIT" + ) + +For ``READ COMMITTED`` and ``SERIALIZABLE``, the Oracle dialect sets the +level at the session level using ``ALTER SESSION``, which is reverted back +to its default setting when the connection is returned to the connection +pool. + +Valid values for ``isolation_level`` include: + +* ``READ COMMITTED`` +* ``AUTOCOMMIT`` +* ``SERIALIZABLE`` + +.. note:: The implementation for the + :meth:`_engine.Connection.get_isolation_level` method as implemented by the + Oracle dialect necessarily forces the start of a transaction using the + Oracle LOCAL_TRANSACTION_ID function; otherwise no level is normally + readable. + + Additionally, the :meth:`_engine.Connection.get_isolation_level` method will + raise an exception if the ``v$transaction`` view is not available due to + permissions or other reasons, which is a common occurrence in Oracle + installations. + + The cx_Oracle dialect attempts to call the + :meth:`_engine.Connection.get_isolation_level` method when the dialect makes + its first connection to the database in order to acquire the + "default"isolation level. This default level is necessary so that the level + can be reset on a connection after it has been temporarily modified using + :meth:`_engine.Connection.execution_options` method. In the common event + that the :meth:`_engine.Connection.get_isolation_level` method raises an + exception due to ``v$transaction`` not being readable as well as any other + database-related failure, the level is assumed to be "READ COMMITTED". No + warning is emitted for this initial first-connect condition as it is + expected to be a common restriction on Oracle databases. + +.. versionadded:: 1.3.16 added support for AUTOCOMMIT to the cx_oracle dialect + as well as the notion of a default isolation level + +.. versionadded:: 1.3.21 Added support for SERIALIZABLE as well as live + reading of the isolation level. + +.. versionchanged:: 1.3.22 In the event that the default isolation + level cannot be read due to permissions on the v$transaction view as + is common in Oracle installations, the default isolation level is hardcoded + to "READ COMMITTED" which was the behavior prior to 1.3.21. + +.. seealso:: + + :ref:`dbapi_autocommit` + +Identifier Casing +----------------- + +In Oracle, the data dictionary represents all case insensitive identifier +names using UPPERCASE text. SQLAlchemy on the other hand considers an +all-lower case identifier name to be case insensitive. The Oracle dialect +converts all case insensitive identifiers to and from those two formats during +schema level communication, such as reflection of tables and indexes. Using +an UPPERCASE name on the SQLAlchemy side indicates a case sensitive +identifier, and SQLAlchemy will quote the name - this will cause mismatches +against data dictionary data received from Oracle, so unless identifier names +have been truly created as case sensitive (i.e. using quoted names), all +lowercase names should be used on the SQLAlchemy side. + +.. _oracle_max_identifier_lengths: + +Max Identifier Lengths +---------------------- + +Oracle has changed the default max identifier length as of Oracle Server +version 12.2. Prior to this version, the length was 30, and for 12.2 and +greater it is now 128. This change impacts SQLAlchemy in the area of +generated SQL label names as well as the generation of constraint names, +particularly in the case where the constraint naming convention feature +described at :ref:`constraint_naming_conventions` is being used. + +To assist with this change and others, Oracle includes the concept of a +"compatibility" version, which is a version number that is independent of the +actual server version in order to assist with migration of Oracle databases, +and may be configured within the Oracle server itself. This compatibility +version is retrieved using the query ``SELECT value FROM v$parameter WHERE +name = 'compatible';``. The SQLAlchemy Oracle dialect, when tasked with +determining the default max identifier length, will attempt to use this query +upon first connect in order to determine the effective compatibility version of +the server, which determines what the maximum allowed identifier length is for +the server. If the table is not available, the server version information is +used instead. + +As of SQLAlchemy 1.4, the default max identifier length for the Oracle dialect +is 128 characters. Upon first connect, the compatibility version is detected +and if it is less than Oracle version 12.2, the max identifier length is +changed to be 30 characters. In all cases, setting the +:paramref:`_sa.create_engine.max_identifier_length` parameter will bypass this +change and the value given will be used as is:: + + engine = create_engine( + "oracle+cx_oracle://scott:tiger@oracle122", + max_identifier_length=30) + +The maximum identifier length comes into play both when generating anonymized +SQL labels in SELECT statements, but more crucially when generating constraint +names from a naming convention. It is this area that has created the need for +SQLAlchemy to change this default conservatively. For example, the following +naming convention produces two very different constraint names based on the +identifier length:: + + from sqlalchemy import Column + from sqlalchemy import Index + from sqlalchemy import Integer + from sqlalchemy import MetaData + from sqlalchemy import Table + from sqlalchemy.dialects import oracle + from sqlalchemy.schema import CreateIndex + + m = MetaData(naming_convention={"ix": "ix_%(column_0N_name)s"}) + + t = Table( + "t", + m, + Column("some_column_name_1", Integer), + Column("some_column_name_2", Integer), + Column("some_column_name_3", Integer), + ) + + ix = Index( + None, + t.c.some_column_name_1, + t.c.some_column_name_2, + t.c.some_column_name_3, + ) + + oracle_dialect = oracle.dialect(max_identifier_length=30) + print(CreateIndex(ix).compile(dialect=oracle_dialect)) + +With an identifier length of 30, the above CREATE INDEX looks like:: + + CREATE INDEX ix_some_column_name_1s_70cd ON t + (some_column_name_1, some_column_name_2, some_column_name_3) + +However with length=128, it becomes:: + + CREATE INDEX ix_some_column_name_1some_column_name_2some_column_name_3 ON t + (some_column_name_1, some_column_name_2, some_column_name_3) + +Applications which have run versions of SQLAlchemy prior to 1.4 on an Oracle +server version 12.2 or greater are therefore subject to the scenario of a +database migration that wishes to "DROP CONSTRAINT" on a name that was +previously generated with the shorter length. This migration will fail when +the identifier length is changed without the name of the index or constraint +first being adjusted. Such applications are strongly advised to make use of +:paramref:`_sa.create_engine.max_identifier_length` +in order to maintain control +of the generation of truncated names, and to fully review and test all database +migrations in a staging environment when changing this value to ensure that the +impact of this change has been mitigated. + +.. versionchanged:: 1.4 the default max_identifier_length for Oracle is 128 + characters, which is adjusted down to 30 upon first connect if an older + version of Oracle server (compatibility version < 12.2) is detected. + + +LIMIT/OFFSET/FETCH Support +-------------------------- + +Methods like :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` currently +use an emulated approach for LIMIT / OFFSET based on window functions, which +involves creation of a subquery using ``ROW_NUMBER`` that is prone to +performance issues as well as SQL construction issues for complex statements. +However, this approach is supported by all Oracle versions. See notes below. + +When using Oracle 12c and above, use the :meth:`_sql.Select.fetch` method +instead; this will render the more modern +``FETCH FIRST N ROW / OFFSET N ROWS`` syntax. + +Notes on LIMIT / OFFSET emulation (when fetch() method cannot be used) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If using :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset`, +or with the ORM the :meth:`_orm.Query.limit` and :meth:`_orm.Query.offset` methods, +and the :meth:`_sql.Select.fetch` method **cannot** be used instead, the following +notes apply: + +* SQLAlchemy currently makes use of ROWNUM to achieve + LIMIT/OFFSET; the exact methodology is taken from + https://blogs.oracle.com/oraclemagazine/on-rownum-and-limiting-results . + +* the "FIRST_ROWS()" optimization keyword is not used by default. To enable + the usage of this optimization directive, specify ``optimize_limits=True`` + to :func:`_sa.create_engine`. + + .. versionchanged:: 1.4 + The Oracle dialect renders limit/offset integer values using a "post + compile" scheme which renders the integer directly before passing the + statement to the cursor for execution. The ``use_binds_for_limits`` flag + no longer has an effect. + + .. seealso:: + + :ref:`change_4808`. + +* A future release may use ``FETCH FIRST N ROW / OFFSET N ROWS`` automatically + when :meth:`_sql.Select.limit`, :meth:`_sql.Select.offset`, :meth:`_orm.Query.limit`, + :meth:`_orm.Query.offset` are used. + +.. _oracle_returning: + +RETURNING Support +----------------- + +The Oracle database supports a limited form of RETURNING, in order to retrieve +result sets of matched rows from INSERT, UPDATE and DELETE statements. +Oracle's RETURNING..INTO syntax only supports one row being returned, as it +relies upon OUT parameters in order to function. In addition, supported +DBAPIs have further limitations (see :ref:`cx_oracle_returning`). + +SQLAlchemy's "implicit returning" feature, which employs RETURNING within an +INSERT and sometimes an UPDATE statement in order to fetch newly generated +primary key values and other SQL defaults and expressions, is normally enabled +on the Oracle backend. By default, "implicit returning" typically only +fetches the value of a single ``nextval(some_seq)`` expression embedded into +an INSERT in order to increment a sequence within an INSERT statement and get +the value back at the same time. To disable this feature across the board, +specify ``implicit_returning=False`` to :func:`_sa.create_engine`:: + + engine = create_engine("oracle://scott:tiger@dsn", + implicit_returning=False) + +Implicit returning can also be disabled on a table-by-table basis as a table +option:: + + # Core Table + my_table = Table("my_table", metadata, ..., implicit_returning=False) + + + # declarative + class MyClass(Base): + __tablename__ = 'my_table' + __table_args__ = {"implicit_returning": False} + +.. seealso:: + + :ref:`cx_oracle_returning` - additional cx_oracle-specific restrictions on + implicit returning. + +ON UPDATE CASCADE +----------------- + +Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based +solution is available at +https://asktom.oracle.com/tkyte/update_cascade/index.html . + +When using the SQLAlchemy ORM, the ORM has limited ability to manually issue +cascading updates - specify ForeignKey objects using the +"deferrable=True, initially='deferred'" keyword arguments, +and specify "passive_updates=False" on each relationship(). + +Oracle 8 Compatibility +---------------------- + +When Oracle 8 is detected, the dialect internally configures itself to the +following behaviors: + +* the use_ansi flag is set to False. This has the effect of converting all + JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN + makes use of Oracle's (+) operator. + +* the NVARCHAR2 and NCLOB datatypes are no longer generated as DDL when + the :class:`~sqlalchemy.types.Unicode` is used - VARCHAR2 and CLOB are + issued instead. This because these types don't seem to work correctly on + Oracle 8 even though they are available. The + :class:`~sqlalchemy.types.NVARCHAR` and + :class:`~sqlalchemy.dialects.oracle.NCLOB` types will always generate + NVARCHAR2 and NCLOB. + +* the "native unicode" mode is disabled when using cx_oracle, i.e. SQLAlchemy + encodes all Python unicode objects to "string" before passing in as bind + parameters. + +Synonym/DBLINK Reflection +------------------------- + +When using reflection with Table objects, the dialect can optionally search +for tables indicated by synonyms, either in local or remote schemas or +accessed over DBLINK, by passing the flag ``oracle_resolve_synonyms=True`` as +a keyword argument to the :class:`_schema.Table` construct:: + + some_table = Table('some_table', autoload_with=some_engine, + oracle_resolve_synonyms=True) + +When this flag is set, the given name (such as ``some_table`` above) will +be searched not just in the ``ALL_TABLES`` view, but also within the +``ALL_SYNONYMS`` view to see if this name is actually a synonym to another +name. If the synonym is located and refers to a DBLINK, the oracle dialect +knows how to locate the table's information using DBLINK syntax(e.g. +``@dblink``). + +``oracle_resolve_synonyms`` is accepted wherever reflection arguments are +accepted, including methods such as :meth:`_schema.MetaData.reflect` and +:meth:`_reflection.Inspector.get_columns`. + +If synonyms are not in use, this flag should be left disabled. + +.. _oracle_constraint_reflection: + +Constraint Reflection +--------------------- + +The Oracle dialect can return information about foreign key, unique, and +CHECK constraints, as well as indexes on tables. + +Raw information regarding these constraints can be acquired using +:meth:`_reflection.Inspector.get_foreign_keys`, +:meth:`_reflection.Inspector.get_unique_constraints`, +:meth:`_reflection.Inspector.get_check_constraints`, and +:meth:`_reflection.Inspector.get_indexes`. + +.. versionchanged:: 1.2 The Oracle dialect can now reflect UNIQUE and + CHECK constraints. + +When using reflection at the :class:`_schema.Table` level, the +:class:`_schema.Table` +will also include these constraints. + +Note the following caveats: + +* When using the :meth:`_reflection.Inspector.get_check_constraints` method, + Oracle + builds a special "IS NOT NULL" constraint for columns that specify + "NOT NULL". This constraint is **not** returned by default; to include + the "IS NOT NULL" constraints, pass the flag ``include_all=True``:: + + from sqlalchemy import create_engine, inspect + + engine = create_engine("oracle+cx_oracle://s:t@dsn") + inspector = inspect(engine) + all_check_constraints = inspector.get_check_constraints( + "some_table", include_all=True) + +* in most cases, when reflecting a :class:`_schema.Table`, + a UNIQUE constraint will + **not** be available as a :class:`.UniqueConstraint` object, as Oracle + mirrors unique constraints with a UNIQUE index in most cases (the exception + seems to be when two or more unique constraints represent the same columns); + the :class:`_schema.Table` will instead represent these using + :class:`.Index` + with the ``unique=True`` flag set. + +* Oracle creates an implicit index for the primary key of a table; this index + is **excluded** from all index results. + +* the list of columns reflected for an index will not include column names + that start with SYS_NC. + +Table names with SYSTEM/SYSAUX tablespaces +------------------------------------------- + +The :meth:`_reflection.Inspector.get_table_names` and +:meth:`_reflection.Inspector.get_temp_table_names` +methods each return a list of table names for the current engine. These methods +are also part of the reflection which occurs within an operation such as +:meth:`_schema.MetaData.reflect`. By default, +these operations exclude the ``SYSTEM`` +and ``SYSAUX`` tablespaces from the operation. In order to change this, the +default list of tablespaces excluded can be changed at the engine level using +the ``exclude_tablespaces`` parameter:: + + # exclude SYSAUX and SOME_TABLESPACE, but not SYSTEM + e = create_engine( + "oracle://scott:tiger@xe", + exclude_tablespaces=["SYSAUX", "SOME_TABLESPACE"]) + +.. versionadded:: 1.1 + +DateTime Compatibility +---------------------- + +Oracle has no datatype known as ``DATETIME``, it instead has only ``DATE``, +which can actually store a date and time value. For this reason, the Oracle +dialect provides a type :class:`_oracle.DATE` which is a subclass of +:class:`.DateTime`. This type has no special behavior, and is only +present as a "marker" for this type; additionally, when a database column +is reflected and the type is reported as ``DATE``, the time-supporting +:class:`_oracle.DATE` type is used. + +.. versionchanged:: 0.9.4 Added :class:`_oracle.DATE` to subclass + :class:`.DateTime`. This is a change as previous versions + would reflect a ``DATE`` column as :class:`_types.DATE`, which subclasses + :class:`.Date`. The only significance here is for schemes that are + examining the type of column for use in special Python translations or + for migrating schemas to other database backends. + +.. _oracle_table_options: + +Oracle Table Options +------------------------- + +The CREATE TABLE phrase supports the following options with Oracle +in conjunction with the :class:`_schema.Table` construct: + + +* ``ON COMMIT``:: + + Table( + "some_table", metadata, ..., + prefixes=['GLOBAL TEMPORARY'], oracle_on_commit='PRESERVE ROWS') + +.. versionadded:: 1.0.0 + +* ``COMPRESS``:: + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=True) + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=6) + + The ``oracle_compress`` parameter accepts either an integer compression + level, or ``True`` to use the default compression level. + +.. versionadded:: 1.0.0 + +.. _oracle_index_options: + +Oracle Specific Index Options +----------------------------- + +Bitmap Indexes +~~~~~~~~~~~~~~ + +You can specify the ``oracle_bitmap`` parameter to create a bitmap index +instead of a B-tree index:: + + Index('my_index', my_table.c.data, oracle_bitmap=True) + +Bitmap indexes cannot be unique and cannot be compressed. SQLAlchemy will not +check for such limitations, only the database will. + +.. versionadded:: 1.0.0 + +Index compression +~~~~~~~~~~~~~~~~~ + +Oracle has a more efficient storage mode for indexes containing lots of +repeated values. Use the ``oracle_compress`` parameter to turn on key +compression:: + + Index('my_index', my_table.c.data, oracle_compress=True) + + Index('my_index', my_table.c.data1, my_table.c.data2, unique=True, + oracle_compress=1) + +The ``oracle_compress`` parameter accepts either an integer specifying the +number of prefix columns to compress, or ``True`` to use the default (all +columns for non-unique indexes, all but the last column for unique indexes). + +.. versionadded:: 1.0.0 + +""" # noqa + +from itertools import groupby +import re + +from ... import Computed +from ... import exc +from ... import schema as sa_schema +from ... import sql +from ... import util +from ...engine import default +from ...engine import reflection +from ...sql import compiler +from ...sql import expression +from ...sql import sqltypes +from ...sql import util as sql_util +from ...sql import visitors +from ...types import BLOB +from ...types import CHAR +from ...types import CLOB +from ...types import FLOAT +from ...types import INTEGER +from ...types import NCHAR +from ...types import NVARCHAR +from ...types import TIMESTAMP +from ...types import VARCHAR +from ...util import compat + +RESERVED_WORDS = set( + "SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN " + "DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED " + "ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE " + "ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE " + "BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES " + "AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS " + "NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER " + "CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR " + "DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL".split() +) + +NO_ARG_FNS = set( + "UID CURRENT_DATE SYSDATE USER " "CURRENT_TIME CURRENT_TIMESTAMP".split() +) + + +class RAW(sqltypes._Binary): + __visit_name__ = "RAW" + + +OracleRaw = RAW + + +class NCLOB(sqltypes.Text): + __visit_name__ = "NCLOB" + + +class VARCHAR2(VARCHAR): + __visit_name__ = "VARCHAR2" + + +NVARCHAR2 = NVARCHAR + + +class NUMBER(sqltypes.Numeric, sqltypes.Integer): + __visit_name__ = "NUMBER" + + def __init__(self, precision=None, scale=None, asdecimal=None): + if asdecimal is None: + asdecimal = bool(scale and scale > 0) + + super(NUMBER, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal + ) + + def adapt(self, impltype): + ret = super(NUMBER, self).adapt(impltype) + # leave a hint for the DBAPI handler + ret._is_oracle_number = True + return ret + + @property + def _type_affinity(self): + if bool(self.scale and self.scale > 0): + return sqltypes.Numeric + else: + return sqltypes.Integer + + +class DOUBLE_PRECISION(sqltypes.Float): + __visit_name__ = "DOUBLE_PRECISION" + + +class BINARY_DOUBLE(sqltypes.Float): + __visit_name__ = "BINARY_DOUBLE" + + +class BINARY_FLOAT(sqltypes.Float): + __visit_name__ = "BINARY_FLOAT" + + +class BFILE(sqltypes.LargeBinary): + __visit_name__ = "BFILE" + + +class LONG(sqltypes.Text): + __visit_name__ = "LONG" + + +class DATE(sqltypes.DateTime): + """Provide the oracle DATE type. + + This type has no special Python behavior, except that it subclasses + :class:`_types.DateTime`; this is to suit the fact that the Oracle + ``DATE`` type supports a time value. + + .. versionadded:: 0.9.4 + + """ + + __visit_name__ = "DATE" + + def _compare_type_affinity(self, other): + return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) + + +class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): + __visit_name__ = "INTERVAL" + + def __init__(self, day_precision=None, second_precision=None): + """Construct an INTERVAL. + + Note that only DAY TO SECOND intervals are currently supported. + This is due to a lack of support for YEAR TO MONTH intervals + within available DBAPIs. + + :param day_precision: the day precision value. this is the number of + digits to store for the day field. Defaults to "2" + :param second_precision: the second precision value. this is the + number of digits to store for the fractional seconds field. + Defaults to "6". + + """ + self.day_precision = day_precision + self.second_precision = second_precision + + @classmethod + def _adapt_from_generic_interval(cls, interval): + return INTERVAL( + day_precision=interval.day_precision, + second_precision=interval.second_precision, + ) + + @property + def _type_affinity(self): + return sqltypes.Interval + + def as_generic(self, allow_nulltype=False): + return sqltypes.Interval( + native=True, + second_precision=self.second_precision, + day_precision=self.day_precision, + ) + + def coerce_compared_value(self, op, value): + return self + + +class ROWID(sqltypes.TypeEngine): + """Oracle ROWID type. + + When used in a cast() or similar, generates ROWID. + + """ + + __visit_name__ = "ROWID" + + +class _OracleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + +colspecs = { + sqltypes.Boolean: _OracleBoolean, + sqltypes.Interval: INTERVAL, + sqltypes.DateTime: DATE, +} + +ischema_names = { + "VARCHAR2": VARCHAR, + "NVARCHAR2": NVARCHAR, + "CHAR": CHAR, + "NCHAR": NCHAR, + "DATE": DATE, + "NUMBER": NUMBER, + "BLOB": BLOB, + "BFILE": BFILE, + "CLOB": CLOB, + "NCLOB": NCLOB, + "TIMESTAMP": TIMESTAMP, + "TIMESTAMP WITH TIME ZONE": TIMESTAMP, + "INTERVAL DAY TO SECOND": INTERVAL, + "RAW": RAW, + "FLOAT": FLOAT, + "DOUBLE PRECISION": DOUBLE_PRECISION, + "LONG": LONG, + "BINARY_DOUBLE": BINARY_DOUBLE, + "BINARY_FLOAT": BINARY_FLOAT, +} + + +class OracleTypeCompiler(compiler.GenericTypeCompiler): + # Note: + # Oracle DATE == DATETIME + # Oracle does not allow milliseconds in DATE + # Oracle does not support TIME columns + + def visit_datetime(self, type_, **kw): + return self.visit_DATE(type_, **kw) + + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) + + def visit_unicode(self, type_, **kw): + if self.dialect._use_nchar_for_unicode: + return self.visit_NVARCHAR2(type_, **kw) + else: + return self.visit_VARCHAR2(type_, **kw) + + def visit_INTERVAL(self, type_, **kw): + return "INTERVAL DAY%s TO SECOND%s" % ( + type_.day_precision is not None + and "(%d)" % type_.day_precision + or "", + type_.second_precision is not None + and "(%d)" % type_.second_precision + or "", + ) + + def visit_LONG(self, type_, **kw): + return "LONG" + + def visit_TIMESTAMP(self, type_, **kw): + if type_.timezone: + return "TIMESTAMP WITH TIME ZONE" + else: + return "TIMESTAMP" + + def visit_DOUBLE_PRECISION(self, type_, **kw): + return self._generate_numeric(type_, "DOUBLE PRECISION", **kw) + + def visit_BINARY_DOUBLE(self, type_, **kw): + return self._generate_numeric(type_, "BINARY_DOUBLE", **kw) + + def visit_BINARY_FLOAT(self, type_, **kw): + return self._generate_numeric(type_, "BINARY_FLOAT", **kw) + + def visit_FLOAT(self, type_, **kw): + # don't support conversion between decimal/binary + # precision yet + kw["no_precision"] = True + return self._generate_numeric(type_, "FLOAT", **kw) + + def visit_NUMBER(self, type_, **kw): + return self._generate_numeric(type_, "NUMBER", **kw) + + def _generate_numeric( + self, type_, name, precision=None, scale=None, no_precision=False, **kw + ): + if precision is None: + precision = type_.precision + + if scale is None: + scale = getattr(type_, "scale", None) + + if no_precision or precision is None: + return name + elif scale is None: + n = "%(name)s(%(precision)s)" + return n % {"name": name, "precision": precision} + else: + n = "%(name)s(%(precision)s, %(scale)s)" + return n % {"name": name, "precision": precision, "scale": scale} + + def visit_string(self, type_, **kw): + return self.visit_VARCHAR2(type_, **kw) + + def visit_VARCHAR2(self, type_, **kw): + return self._visit_varchar(type_, "", "2") + + def visit_NVARCHAR2(self, type_, **kw): + return self._visit_varchar(type_, "N", "2") + + visit_NVARCHAR = visit_NVARCHAR2 + + def visit_VARCHAR(self, type_, **kw): + return self._visit_varchar(type_, "", "") + + def _visit_varchar(self, type_, n, num): + if not type_.length: + return "%(n)sVARCHAR%(two)s" % {"two": num, "n": n} + elif not n and self.dialect._supports_char_length: + varchar = "VARCHAR%(two)s(%(length)s CHAR)" + return varchar % {"length": type_.length, "two": num} + else: + varchar = "%(n)sVARCHAR%(two)s(%(length)s)" + return varchar % {"length": type_.length, "two": num, "n": n} + + def visit_text(self, type_, **kw): + return self.visit_CLOB(type_, **kw) + + def visit_unicode_text(self, type_, **kw): + if self.dialect._use_nchar_for_unicode: + return self.visit_NCLOB(type_, **kw) + else: + return self.visit_CLOB(type_, **kw) + + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) + + def visit_big_integer(self, type_, **kw): + return self.visit_NUMBER(type_, precision=19, **kw) + + def visit_boolean(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) + + def visit_RAW(self, type_, **kw): + if type_.length: + return "RAW(%(length)s)" % {"length": type_.length} + else: + return "RAW" + + def visit_ROWID(self, type_, **kw): + return "ROWID" + + +class OracleCompiler(compiler.SQLCompiler): + """Oracle compiler modifies the lexical structure of Select + statements to work under non-ANSI configured Oracle databases, if + the use_ansi flag is False. + """ + + compound_keywords = util.update_copy( + compiler.SQLCompiler.compound_keywords, + {expression.CompoundSelect.EXCEPT: "MINUS"}, + ) + + def __init__(self, *args, **kwargs): + self.__wheres = {} + super(OracleCompiler, self).__init__(*args, **kwargs) + + def visit_mod_binary(self, binary, operator, **kw): + return "mod(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_char_length_func(self, fn, **kw): + return "LENGTH" + self.function_argspec(fn, **kw) + + def visit_match_op_binary(self, binary, operator, **kw): + return "CONTAINS (%s, %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_true(self, expr, **kw): + return "1" + + def visit_false(self, expr, **kw): + return "0" + + def get_cte_preamble(self, recursive): + return "WITH" + + def get_select_hint_text(self, byfroms): + return " ".join("/*+ %s */" % text for table, text in byfroms.items()) + + def function_argspec(self, fn, **kw): + if len(fn.clauses) > 0 or fn.name.upper() not in NO_ARG_FNS: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) + else: + return "" + + def visit_function(self, func, **kw): + text = super(OracleCompiler, self).visit_function(func, **kw) + if kw.get("asfrom", False): + text = "TABLE (%s)" % func + return text + + def visit_table_valued_column(self, element, **kw): + text = super(OracleCompiler, self).visit_table_valued_column( + element, **kw + ) + text = "COLUMN_VALUE " + text + return text + + def default_from(self): + """Called when a ``SELECT`` statement has no froms, + and no ``FROM`` clause is to be appended. + + The Oracle compiler tacks a "FROM DUAL" to the statement. + """ + + return " FROM DUAL" + + def visit_join(self, join, from_linter=None, **kwargs): + if self.dialect.use_ansi: + return compiler.SQLCompiler.visit_join( + self, join, from_linter=from_linter, **kwargs + ) + else: + if from_linter: + from_linter.edges.add((join.left, join.right)) + + kwargs["asfrom"] = True + if isinstance(join.right, expression.FromGrouping): + right = join.right.element + else: + right = join.right + return ( + self.process(join.left, from_linter=from_linter, **kwargs) + + ", " + + self.process(right, from_linter=from_linter, **kwargs) + ) + + def _get_nonansi_join_whereclause(self, froms): + clauses = [] + + def visit_join(join): + if join.isouter: + # https://docs.oracle.com/database/121/SQLRF/queries006.htm#SQLRF52354 + # "apply the outer join operator (+) to all columns of B in + # the join condition in the WHERE clause" - that is, + # unconditionally regardless of operator or the other side + def visit_binary(binary): + if isinstance( + binary.left, expression.ColumnClause + ) and join.right.is_derived_from(binary.left.table): + binary.left = _OuterJoinColumn(binary.left) + elif isinstance( + binary.right, expression.ColumnClause + ) and join.right.is_derived_from(binary.right.table): + binary.right = _OuterJoinColumn(binary.right) + + clauses.append( + visitors.cloned_traverse( + join.onclause, {}, {"binary": visit_binary} + ) + ) + else: + clauses.append(join.onclause) + + for j in join.left, join.right: + if isinstance(j, expression.Join): + visit_join(j) + elif isinstance(j, expression.FromGrouping): + visit_join(j.element) + + for f in froms: + if isinstance(f, expression.Join): + visit_join(f) + + if not clauses: + return None + else: + return sql.and_(*clauses) + + def visit_outer_join_column(self, vc, **kw): + return self.process(vc.column, **kw) + "(+)" + + def visit_sequence(self, seq, **kw): + return self.preparer.format_sequence(seq) + ".nextval" + + def get_render_as_alias_suffix(self, alias_name_text): + """Oracle doesn't like ``FROM table AS alias``""" + + return " " + alias_name_text + + def returning_clause(self, stmt, returning_cols): + columns = [] + binds = [] + + for i, column in enumerate( + expression._select_iterables(returning_cols) + ): + if ( + self.isupdate + and isinstance(column, sa_schema.Column) + and isinstance(column.server_default, Computed) + and not self.dialect._supports_update_returning_computed_cols + ): + util.warn( + "Computed columns don't work with Oracle UPDATE " + "statements that use RETURNING; the value of the column " + "*before* the UPDATE takes place is returned. It is " + "advised to not use RETURNING with an Oracle computed " + "column. Consider setting implicit_returning to False on " + "the Table object in order to avoid implicit RETURNING " + "clauses from being generated for this Table." + ) + if column.type._has_column_expression: + col_expr = column.type.column_expression(column) + else: + col_expr = column + + outparam = sql.outparam("ret_%d" % i, type_=column.type) + self.binds[outparam.key] = outparam + binds.append( + self.bindparam_string(self._truncate_bindparam(outparam)) + ) + + # ensure the ExecutionContext.get_out_parameters() method is + # *not* called; the cx_Oracle dialect wants to handle these + # parameters separately + self.has_out_parameters = False + + columns.append(self.process(col_expr, within_columns_clause=False)) + + self._add_to_result_map( + getattr(col_expr, "name", col_expr._anon_name_label), + getattr(col_expr, "name", col_expr._anon_name_label), + ( + column, + getattr(column, "name", None), + getattr(column, "key", None), + ), + column.type, + ) + + return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) + + def translate_select_structure(self, select_stmt, **kwargs): + select = select_stmt + + if not getattr(select, "_oracle_visit", None): + if not self.dialect.use_ansi: + froms = self._display_froms_for_select( + select, kwargs.get("asfrom", False) + ) + whereclause = self._get_nonansi_join_whereclause(froms) + if whereclause is not None: + select = select.where(whereclause) + select._oracle_visit = True + + # if fetch is used this is not needed + if ( + select._has_row_limiting_clause + and select._fetch_clause is None + ): + limit_clause = select._limit_clause + offset_clause = select._offset_clause + + if select._simple_int_clause(limit_clause): + limit_clause = limit_clause.render_literal_execute() + + if select._simple_int_clause(offset_clause): + offset_clause = offset_clause.render_literal_execute() + + # currently using form at: + # https://blogs.oracle.com/oraclemagazine/\ + # on-rownum-and-limiting-results + + orig_select = select + select = select._generate() + select._oracle_visit = True + + # add expressions to accommodate FOR UPDATE OF + for_update = select._for_update_arg + if for_update is not None and for_update.of: + for_update = for_update._clone() + for_update._copy_internals() + + for elem in for_update.of: + if not select.selected_columns.contains_column(elem): + select = select.add_columns(elem) + + # Wrap the middle select and add the hint + inner_subquery = select.alias() + limitselect = sql.select( + *[ + c + for c in inner_subquery.c + if orig_select.selected_columns.corresponding_column(c) + is not None + ] + ) + + if ( + limit_clause is not None + and self.dialect.optimize_limits + and select._simple_int_clause(limit_clause) + ): + limitselect = limitselect.prefix_with( + expression.text( + "/*+ FIRST_ROWS(%s) */" + % self.process(limit_clause, **kwargs) + ) + ) + + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + # add expressions to accommodate FOR UPDATE OF + if for_update is not None and for_update.of: + + adapter = sql_util.ClauseAdapter(inner_subquery) + for_update.of = [ + adapter.traverse(elem) for elem in for_update.of + ] + + # If needed, add the limiting clause + if limit_clause is not None: + if select._simple_int_clause(limit_clause) and ( + offset_clause is None + or select._simple_int_clause(offset_clause) + ): + max_row = limit_clause + + if offset_clause is not None: + max_row = max_row + offset_clause + + else: + max_row = limit_clause + + if offset_clause is not None: + max_row = max_row + offset_clause + limitselect = limitselect.where( + sql.literal_column("ROWNUM") <= max_row + ) + + # If needed, add the ora_rn, and wrap again with offset. + if offset_clause is None: + limitselect._for_update_arg = for_update + select = limitselect + else: + limitselect = limitselect.add_columns( + sql.literal_column("ROWNUM").label("ora_rn") + ) + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + if for_update is not None and for_update.of: + limitselect_cols = limitselect.selected_columns + for elem in for_update.of: + if ( + limitselect_cols.corresponding_column(elem) + is None + ): + limitselect = limitselect.add_columns(elem) + + limit_subquery = limitselect.alias() + origselect_cols = orig_select.selected_columns + offsetselect = sql.select( + *[ + c + for c in limit_subquery.c + if origselect_cols.corresponding_column(c) + is not None + ] + ) + + offsetselect._oracle_visit = True + offsetselect._is_wrapper = True + + if for_update is not None and for_update.of: + adapter = sql_util.ClauseAdapter(limit_subquery) + for_update.of = [ + adapter.traverse(elem) for elem in for_update.of + ] + + offsetselect = offsetselect.where( + sql.literal_column("ora_rn") > offset_clause + ) + + offsetselect._for_update_arg = for_update + select = offsetselect + + return select + + def limit_clause(self, select, **kw): + return "" + + def visit_empty_set_expr(self, type_): + return "SELECT 1 FROM DUAL WHERE 1!=1" + + def for_update_clause(self, select, **kw): + if self.is_subquery(): + return "" + + tmp = " FOR UPDATE" + + if select._for_update_arg.of: + tmp += " OF " + ", ".join( + self.process(elem, **kw) for elem in select._for_update_arg.of + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + if select._for_update_arg.skip_locked: + tmp += " SKIP LOCKED" + + return tmp + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "DECODE(%s, %s, 0, 1) = 1" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "DECODE(%s, %s, 0, 1) = 0" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def _get_regexp_args(self, binary, kw): + string = self.process(binary.left, **kw) + pattern = self.process(binary.right, **kw) + flags = binary.modifiers["flags"] + if flags is not None: + flags = self.process(flags, **kw) + return string, pattern, flags + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + string, pattern, flags = self._get_regexp_args(binary, kw) + if flags is None: + return "REGEXP_LIKE(%s, %s)" % (string, pattern) + else: + return "REGEXP_LIKE(%s, %s, %s)" % (string, pattern, flags) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return "NOT %s" % self.visit_regexp_match_op_binary( + binary, operator, **kw + ) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + string, pattern, flags = self._get_regexp_args(binary, kw) + replacement = self.process(binary.modifiers["replacement"], **kw) + if flags is None: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + string, + pattern, + replacement, + ) + else: + return "REGEXP_REPLACE(%s, %s, %s, %s)" % ( + string, + pattern, + replacement, + flags, + ) + + +class OracleDDLCompiler(compiler.DDLCompiler): + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % constraint.ondelete + + # oracle has no ON UPDATE CASCADE - + # its only available via triggers + # https://asktom.oracle.com/tkyte/update_cascade/index.html + if constraint.onupdate is not None: + util.warn( + "Oracle does not contain native UPDATE CASCADE " + "functionality - onupdates will not be rendered for foreign " + "keys. Consider using deferrable=True, initially='deferred' " + "or triggers." + ) + + return text + + def visit_drop_table_comment(self, drop): + return "COMMENT ON TABLE %s IS ''" % self.preparer.format_table( + drop.element + ) + + def visit_create_index(self, create): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + if index.dialect_options["oracle"]["bitmap"]: + text += "BITMAP " + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=True), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) + if index.dialect_options["oracle"]["compress"] is not False: + if index.dialect_options["oracle"]["compress"] is True: + text += " COMPRESS" + else: + text += " COMPRESS %d" % ( + index.dialect_options["oracle"]["compress"] + ) + return text + + def post_create_table(self, table): + table_opts = [] + opts = table.dialect_options["oracle"] + + if opts["on_commit"]: + on_commit_options = opts["on_commit"].replace("_", " ").upper() + table_opts.append("\n ON COMMIT %s" % on_commit_options) + + if opts["compress"]: + if opts["compress"] is True: + table_opts.append("\n COMPRESS") + else: + table_opts.append("\n COMPRESS FOR %s" % (opts["compress"])) + + return "".join(table_opts) + + def get_identity_options(self, identity_options): + text = super(OracleDDLCompiler, self).get_identity_options( + identity_options + ) + text = text.replace("NO MINVALUE", "NOMINVALUE") + text = text.replace("NO MAXVALUE", "NOMAXVALUE") + text = text.replace("NO CYCLE", "NOCYCLE") + text = text.replace("NO ORDER", "NOORDER") + return text + + def visit_computed_column(self, generated): + text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + if generated.persisted is True: + raise exc.CompileError( + "Oracle computed columns do not support 'stored' persistence; " + "set the 'persisted' flag to None or False for Oracle support." + ) + elif generated.persisted is False: + text += " VIRTUAL" + return text + + def visit_identity_column(self, identity, **kw): + if identity.always is None: + kind = "" + else: + kind = "ALWAYS" if identity.always else "BY DEFAULT" + text = "GENERATED %s" % kind + if identity.on_null: + text += " ON NULL" + text += " AS IDENTITY" + options = self.get_identity_options(identity) + if options: + text += " (%s)" % options + return text + + +class OracleIdentifierPreparer(compiler.IdentifierPreparer): + + reserved_words = {x.lower() for x in RESERVED_WORDS} + illegal_initial_characters = {str(dig) for dig in range(0, 10)}.union( + ["_", "$"] + ) + + def _bindparam_requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(util.text_type(value)) + ) + + def format_savepoint(self, savepoint): + name = savepoint.ident.lstrip("_") + return super(OracleIdentifierPreparer, self).format_savepoint( + savepoint, name + ) + + +class OracleExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq, type_): + return self._execute_scalar( + "SELECT " + + self.identifier_preparer.format_sequence(seq) + + ".nextval FROM DUAL", + type_, + ) + + +class OracleDialect(default.DefaultDialect): + name = "oracle" + supports_statement_cache = True + supports_alter = True + supports_unicode_statements = False + supports_unicode_binds = False + max_identifier_length = 128 + + supports_simple_order_by_label = False + cte_follows_insert = True + + supports_sequences = True + sequences_optional = False + postfetch_lastrowid = False + + default_paramstyle = "named" + colspecs = colspecs + ischema_names = ischema_names + requires_name_normalize = True + + supports_comments = True + + supports_default_values = False + supports_default_metavalue = True + supports_empty_insert = False + supports_identity_columns = True + + statement_compiler = OracleCompiler + ddl_compiler = OracleDDLCompiler + type_compiler = OracleTypeCompiler + preparer = OracleIdentifierPreparer + execution_ctx_cls = OracleExecutionContext + + reflection_options = ("oracle_resolve_synonyms",) + + _use_nchar_for_unicode = False + + construct_arguments = [ + ( + sa_schema.Table, + {"resolve_synonyms": False, "on_commit": None, "compress": False}, + ), + (sa_schema.Index, {"bitmap": False, "compress": False}), + ] + + @util.deprecated_params( + use_binds_for_limits=( + "1.4", + "The ``use_binds_for_limits`` Oracle dialect parameter is " + "deprecated. The dialect now renders LIMIT /OFFSET integers " + "inline in all cases using a post-compilation hook, so that the " + "value is still represented by a 'bound parameter' on the Core " + "Expression side.", + ) + ) + def __init__( + self, + use_ansi=True, + optimize_limits=False, + use_binds_for_limits=None, + use_nchar_for_unicode=False, + exclude_tablespaces=("SYSTEM", "SYSAUX"), + **kwargs + ): + default.DefaultDialect.__init__(self, **kwargs) + self._use_nchar_for_unicode = use_nchar_for_unicode + self.use_ansi = use_ansi + self.optimize_limits = optimize_limits + self.exclude_tablespaces = exclude_tablespaces + + def initialize(self, connection): + super(OracleDialect, self).initialize(connection) + + self.implicit_returning = self.__dict__.get( + "implicit_returning", self.server_version_info > (10,) + ) + + if self._is_oracle_8: + self.colspecs = self.colspecs.copy() + self.colspecs.pop(sqltypes.Interval) + self.use_ansi = False + + self.supports_identity_columns = self.server_version_info >= (12,) + + def _get_effective_compat_server_version_info(self, connection): + # dialect does not need compat levels below 12.2, so don't query + # in those cases + + if self.server_version_info < (12, 2): + return self.server_version_info + try: + compat = connection.exec_driver_sql( + "SELECT value FROM v$parameter WHERE name = 'compatible'" + ).scalar() + except exc.DBAPIError: + compat = None + + if compat: + try: + return tuple(int(x) for x in compat.split(".")) + except: + return self.server_version_info + else: + return self.server_version_info + + @property + def _is_oracle_8(self): + return self.server_version_info and self.server_version_info < (9,) + + @property + def _supports_table_compression(self): + return self.server_version_info and self.server_version_info >= (10, 1) + + @property + def _supports_table_compress_for(self): + return self.server_version_info and self.server_version_info >= (11,) + + @property + def _supports_char_length(self): + return not self._is_oracle_8 + + @property + def _supports_update_returning_computed_cols(self): + # on version 18 this error is no longet present while it happens on 11 + # it may work also on versions before the 18 + return self.server_version_info and self.server_version_info >= (18,) + + def do_release_savepoint(self, connection, name): + # Oracle does not support RELEASE SAVEPOINT + pass + + def _check_max_identifier_length(self, connection): + if self._get_effective_compat_server_version_info(connection) < ( + 12, + 2, + ): + return 30 + else: + # use the default + return None + + def _check_unicode_returns(self, connection): + additional_tests = [ + expression.cast( + expression.literal_column("'test nvarchar2 returns'"), + sqltypes.NVARCHAR(60), + ) + ] + return super(OracleDialect, self)._check_unicode_returns( + connection, additional_tests + ) + + _isolation_lookup = ["READ COMMITTED", "SERIALIZABLE"] + + def get_isolation_level(self, connection): + raise NotImplementedError("implemented by cx_Oracle dialect") + + def get_default_isolation_level(self, dbapi_conn): + try: + return self.get_isolation_level(dbapi_conn) + except NotImplementedError: + raise + except: + return "READ COMMITTED" + + def set_isolation_level(self, connection, level): + raise NotImplementedError("implemented by cx_Oracle dialect") + + def has_table(self, connection, table_name, schema=None): + self._ensure_has_table_connection(connection) + + if not schema: + schema = self.default_schema_name + + cursor = connection.execute( + sql.text( + "SELECT table_name FROM all_tables " + "WHERE table_name = CAST(:name AS VARCHAR2(128)) " + "AND owner = CAST(:schema_name AS VARCHAR2(128))" + ), + dict( + name=self.denormalize_name(table_name), + schema_name=self.denormalize_name(schema), + ), + ) + return cursor.first() is not None + + def has_sequence(self, connection, sequence_name, schema=None): + if not schema: + schema = self.default_schema_name + cursor = connection.execute( + sql.text( + "SELECT sequence_name FROM all_sequences " + "WHERE sequence_name = :name AND " + "sequence_owner = :schema_name" + ), + dict( + name=self.denormalize_name(sequence_name), + schema_name=self.denormalize_name(schema), + ), + ) + return cursor.first() is not None + + def _get_default_schema_name(self, connection): + return self.normalize_name( + connection.exec_driver_sql( + "select sys_context( 'userenv', 'current_schema' ) from dual" + ).scalar() + ) + + def _resolve_synonym( + self, + connection, + desired_owner=None, + desired_synonym=None, + desired_table=None, + ): + """search for a local synonym matching the given desired owner/name. + + if desired_owner is None, attempts to locate a distinct owner. + + returns the actual name, owner, dblink name, and synonym name if + found. + """ + + q = ( + "SELECT owner, table_owner, table_name, db_link, " + "synonym_name FROM all_synonyms WHERE " + ) + clauses = [] + params = {} + if desired_synonym: + clauses.append( + "synonym_name = CAST(:synonym_name AS VARCHAR2(128))" + ) + params["synonym_name"] = desired_synonym + if desired_owner: + clauses.append("owner = CAST(:desired_owner AS VARCHAR2(128))") + params["desired_owner"] = desired_owner + if desired_table: + clauses.append("table_name = CAST(:tname AS VARCHAR2(128))") + params["tname"] = desired_table + + q += " AND ".join(clauses) + + result = connection.execution_options(future_result=True).execute( + sql.text(q), params + ) + if desired_owner: + row = result.mappings().first() + if row: + return ( + row["table_name"], + row["table_owner"], + row["db_link"], + row["synonym_name"], + ) + else: + return None, None, None, None + else: + rows = result.mappings().all() + if len(rows) > 1: + raise AssertionError( + "There are multiple tables visible to the schema, you " + "must specify owner" + ) + elif len(rows) == 1: + row = rows[0] + return ( + row["table_name"], + row["table_owner"], + row["db_link"], + row["synonym_name"], + ) + else: + return None, None, None, None + + @reflection.cache + def _prepare_reflection_args( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + + if resolve_synonyms: + actual_name, owner, dblink, synonym = self._resolve_synonym( + connection, + desired_owner=self.denormalize_name(schema), + desired_synonym=self.denormalize_name(table_name), + ) + else: + actual_name, owner, dblink, synonym = None, None, None, None + if not actual_name: + actual_name = self.denormalize_name(table_name) + + if dblink: + # using user_db_links here since all_db_links appears + # to have more restricted permissions. + # https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm + # will need to hear from more users if we are doing + # the right thing here. See [ticket:2619] + owner = connection.scalar( + sql.text( + "SELECT username FROM user_db_links " "WHERE db_link=:link" + ), + dict(link=dblink), + ) + dblink = "@" + dblink + elif not owner: + owner = self.denormalize_name(schema or self.default_schema_name) + + return (actual_name, owner, dblink or "", synonym) + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = "SELECT username FROM all_users ORDER BY username" + cursor = connection.exec_driver_sql(s) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + schema = self.denormalize_name(schema or self.default_schema_name) + + # note that table_names() isn't loading DBLINKed or synonym'ed tables + if schema is None: + schema = self.default_schema_name + + sql_str = "SELECT table_name FROM all_tables WHERE " + if self.exclude_tablespaces: + sql_str += ( + "nvl(tablespace_name, 'no tablespace') " + "NOT IN (%s) AND " + % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) + ) + sql_str += ( + "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL" + ) + + cursor = connection.execute(sql.text(sql_str), dict(owner=schema)) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + schema = self.denormalize_name(self.default_schema_name) + + sql_str = "SELECT table_name FROM all_tables WHERE " + if self.exclude_tablespaces: + sql_str += ( + "nvl(tablespace_name, 'no tablespace') " + "NOT IN (%s) AND " + % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) + ) + sql_str += ( + "OWNER = :owner " + "AND IOT_NAME IS NULL " + "AND DURATION IS NOT NULL" + ) + + cursor = connection.execute(sql.text(sql_str), dict(owner=schema)) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + schema = self.denormalize_name(schema or self.default_schema_name) + s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner") + cursor = connection.execute( + s, dict(owner=self.denormalize_name(schema)) + ) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_sequence_names(self, connection, schema=None, **kw): + if not schema: + schema = self.default_schema_name + cursor = connection.execute( + sql.text( + "SELECT sequence_name FROM all_sequences " + "WHERE sequence_owner = :schema_name" + ), + dict(schema_name=self.denormalize_name(schema)), + ) + return [self.normalize_name(row[0]) for row in cursor] + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + options = {} + + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + + params = {"table_name": table_name} + + columns = ["table_name"] + if self._supports_table_compression: + columns.append("compression") + if self._supports_table_compress_for: + columns.append("compress_for") + + text = ( + "SELECT %(columns)s " + "FROM ALL_TABLES%(dblink)s " + "WHERE table_name = CAST(:table_name AS VARCHAR(128))" + ) + + if schema is not None: + params["owner"] = schema + text += " AND owner = CAST(:owner AS VARCHAR(128)) " + text = text % {"dblink": dblink, "columns": ", ".join(columns)} + + result = connection.execute(sql.text(text), params) + + enabled = dict(DISABLED=False, ENABLED=True) + + row = result.first() + if row: + if "compression" in row._fields and enabled.get( + row.compression, False + ): + if "compress_for" in row._fields: + options["oracle_compress"] = row.compress_for + else: + options["oracle_compress"] = True + + return options + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + columns = [] + if self._supports_char_length: + char_length_col = "char_length" + else: + char_length_col = "data_length" + + if self.server_version_info >= (12,): + identity_cols = """\ + col.default_on_null, + ( + SELECT id.generation_type || ',' || id.IDENTITY_OPTIONS + FROM ALL_TAB_IDENTITY_COLS%(dblink)s id + WHERE col.table_name = id.table_name + AND col.column_name = id.column_name + AND col.owner = id.owner + ) AS identity_options""" % { + "dblink": dblink + } + else: + identity_cols = "NULL as default_on_null, NULL as identity_options" + + params = {"table_name": table_name} + + text = """ + SELECT + col.column_name, + col.data_type, + col.%(char_length_col)s, + col.data_precision, + col.data_scale, + col.nullable, + col.data_default, + com.comments, + col.virtual_column, + %(identity_cols)s + FROM all_tab_cols%(dblink)s col + LEFT JOIN all_col_comments%(dblink)s com + ON col.table_name = com.table_name + AND col.column_name = com.column_name + AND col.owner = com.owner + WHERE col.table_name = CAST(:table_name AS VARCHAR2(128)) + AND col.hidden_column = 'NO' + """ + if schema is not None: + params["owner"] = schema + text += " AND col.owner = :owner " + text += " ORDER BY col.column_id" + text = text % { + "dblink": dblink, + "char_length_col": char_length_col, + "identity_cols": identity_cols, + } + + c = connection.execute(sql.text(text), params) + + for row in c: + colname = self.normalize_name(row[0]) + orig_colname = row[0] + coltype = row[1] + length = row[2] + precision = row[3] + scale = row[4] + nullable = row[5] == "Y" + default = row[6] + comment = row[7] + generated = row[8] + default_on_nul = row[9] + identity_options = row[10] + + if coltype == "NUMBER": + if precision is None and scale == 0: + coltype = INTEGER() + else: + coltype = NUMBER(precision, scale) + elif coltype == "FLOAT": + # TODO: support "precision" here as "binary_precision" + coltype = FLOAT() + elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"): + coltype = self.ischema_names.get(coltype)(length) + elif "WITH TIME ZONE" in coltype: + coltype = TIMESTAMP(timezone=True) + else: + coltype = re.sub(r"\(\d+\)", "", coltype) + try: + coltype = self.ischema_names[coltype] + except KeyError: + util.warn( + "Did not recognize type '%s' of column '%s'" + % (coltype, colname) + ) + coltype = sqltypes.NULLTYPE + + if generated == "YES": + computed = dict(sqltext=default) + default = None + else: + computed = None + + if identity_options is not None: + identity = self._parse_identity_options( + identity_options, default_on_nul + ) + default = None + else: + identity = None + + cdict = { + "name": colname, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": "auto", + "comment": comment, + } + if orig_colname.lower() == orig_colname: + cdict["quote"] = True + if computed is not None: + cdict["computed"] = computed + if identity is not None: + cdict["identity"] = identity + + columns.append(cdict) + return columns + + def _parse_identity_options(self, identity_options, default_on_nul): + # identity_options is a string that starts with 'ALWAYS,' or + # 'BY DEFAULT,' and continues with + # START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1, + # CYCLE_FLAG: N, CACHE_SIZE: 1, ORDER_FLAG: N, SCALE_FLAG: N, + # EXTEND_FLAG: N, SESSION_FLAG: N, KEEP_VALUE: N + parts = [p.strip() for p in identity_options.split(",")] + identity = { + "always": parts[0] == "ALWAYS", + "on_null": default_on_nul == "YES", + } + + for part in parts[1:]: + option, value = part.split(":") + value = value.strip() + + if "START WITH" in option: + identity["start"] = compat.long_type(value) + elif "INCREMENT BY" in option: + identity["increment"] = compat.long_type(value) + elif "MAX_VALUE" in option: + identity["maxvalue"] = compat.long_type(value) + elif "MIN_VALUE" in option: + identity["minvalue"] = compat.long_type(value) + elif "CYCLE_FLAG" in option: + identity["cycle"] = value == "Y" + elif "CACHE_SIZE" in option: + identity["cache"] = compat.long_type(value) + elif "ORDER_FLAG" in option: + identity["order"] = value == "Y" + return identity + + @reflection.cache + def get_table_comment( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + + info_cache = kw.get("info_cache") + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + + if not schema: + schema = self.default_schema_name + + COMMENT_SQL = """ + SELECT comments + FROM all_tab_comments + WHERE table_name = CAST(:table_name AS VARCHAR(128)) + AND owner = CAST(:schema_name AS VARCHAR(128)) + """ + + c = connection.execute( + sql.text(COMMENT_SQL), + dict(table_name=table_name, schema_name=schema), + ) + return {"text": c.scalar()} + + @reflection.cache + def get_indexes( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + + info_cache = kw.get("info_cache") + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + indexes = [] + + params = {"table_name": table_name} + text = ( + "SELECT a.index_name, a.column_name, " + "\nb.index_type, b.uniqueness, b.compression, b.prefix_length " + "\nFROM ALL_IND_COLUMNS%(dblink)s a, " + "\nALL_INDEXES%(dblink)s b " + "\nWHERE " + "\na.index_name = b.index_name " + "\nAND a.table_owner = b.table_owner " + "\nAND a.table_name = b.table_name " + "\nAND a.table_name = CAST(:table_name AS VARCHAR(128))" + ) + + if schema is not None: + params["schema"] = schema + text += "AND a.table_owner = :schema " + + text += "ORDER BY a.index_name, a.column_position" + + text = text % {"dblink": dblink} + + q = sql.text(text) + rp = connection.execute(q, params) + indexes = [] + last_index_name = None + pk_constraint = self.get_pk_constraint( + connection, + table_name, + schema, + resolve_synonyms=resolve_synonyms, + dblink=dblink, + info_cache=kw.get("info_cache"), + ) + + uniqueness = dict(NONUNIQUE=False, UNIQUE=True) + enabled = dict(DISABLED=False, ENABLED=True) + + oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE) + + index = None + for rset in rp: + index_name_normalized = self.normalize_name(rset.index_name) + + # skip primary key index. This is refined as of + # [ticket:5421]. Note that ALL_INDEXES.GENERATED will by "Y" + # if the name of this index was generated by Oracle, however + # if a named primary key constraint was created then this flag + # is false. + if ( + pk_constraint + and index_name_normalized == pk_constraint["name"] + ): + continue + + if rset.index_name != last_index_name: + index = dict( + name=index_name_normalized, + column_names=[], + dialect_options={}, + ) + indexes.append(index) + index["unique"] = uniqueness.get(rset.uniqueness, False) + + if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"): + index["dialect_options"]["oracle_bitmap"] = True + if enabled.get(rset.compression, False): + index["dialect_options"][ + "oracle_compress" + ] = rset.prefix_length + + # filter out Oracle SYS_NC names. could also do an outer join + # to the all_tab_columns table and check for real col names there. + if not oracle_sys_col.match(rset.column_name): + index["column_names"].append( + self.normalize_name(rset.column_name) + ) + last_index_name = rset.index_name + + return indexes + + @reflection.cache + def _get_constraint_data( + self, connection, table_name, schema=None, dblink="", **kw + ): + + params = {"table_name": table_name} + + text = ( + "SELECT" + "\nac.constraint_name," # 0 + "\nac.constraint_type," # 1 + "\nloc.column_name AS local_column," # 2 + "\nrem.table_name AS remote_table," # 3 + "\nrem.column_name AS remote_column," # 4 + "\nrem.owner AS remote_owner," # 5 + "\nloc.position as loc_pos," # 6 + "\nrem.position as rem_pos," # 7 + "\nac.search_condition," # 8 + "\nac.delete_rule" # 9 + "\nFROM all_constraints%(dblink)s ac," + "\nall_cons_columns%(dblink)s loc," + "\nall_cons_columns%(dblink)s rem" + "\nWHERE ac.table_name = CAST(:table_name AS VARCHAR2(128))" + "\nAND ac.constraint_type IN ('R','P', 'U', 'C')" + ) + + if schema is not None: + params["owner"] = schema + text += "\nAND ac.owner = CAST(:owner AS VARCHAR2(128))" + + text += ( + "\nAND ac.owner = loc.owner" + "\nAND ac.constraint_name = loc.constraint_name" + "\nAND ac.r_owner = rem.owner(+)" + "\nAND ac.r_constraint_name = rem.constraint_name(+)" + "\nAND (rem.position IS NULL or loc.position=rem.position)" + "\nORDER BY ac.constraint_name, loc.position" + ) + + text = text % {"dblink": dblink} + rp = connection.execute(sql.text(text), params) + constraint_data = rp.fetchall() + return constraint_data + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + pkeys = [] + constraint_name = None + constraint_data = self._get_constraint_data( + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) + + for row in constraint_data: + ( + cons_name, + cons_type, + local_column, + remote_table, + remote_column, + remote_owner, + ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + if cons_type == "P": + if constraint_name is None: + constraint_name = self.normalize_name(cons_name) + pkeys.append(local_column) + return {"constrained_columns": pkeys, "name": constraint_name} + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + """ + + kw arguments can be: + + oracle_resolve_synonyms + + dblink + + """ + requested_schema = schema # to check later on + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + + constraint_data = self._get_constraint_data( + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) + + def fkey_rec(): + return { + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + "options": {}, + } + + fkeys = util.defaultdict(fkey_rec) + + for row in constraint_data: + ( + cons_name, + cons_type, + local_column, + remote_table, + remote_column, + remote_owner, + ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + + cons_name = self.normalize_name(cons_name) + + if cons_type == "R": + if remote_table is None: + # ticket 363 + util.warn( + ( + "Got 'None' querying 'table_name' from " + "all_cons_columns%(dblink)s - does the user have " + "proper rights to the table?" + ) + % {"dblink": dblink} + ) + continue + + rec = fkeys[cons_name] + rec["name"] = cons_name + local_cols, remote_cols = ( + rec["constrained_columns"], + rec["referred_columns"], + ) + + if not rec["referred_table"]: + if resolve_synonyms: + ( + ref_remote_name, + ref_remote_owner, + ref_dblink, + ref_synonym, + ) = self._resolve_synonym( + connection, + desired_owner=self.denormalize_name(remote_owner), + desired_table=self.denormalize_name(remote_table), + ) + if ref_synonym: + remote_table = self.normalize_name(ref_synonym) + remote_owner = self.normalize_name( + ref_remote_owner + ) + + rec["referred_table"] = remote_table + + if ( + requested_schema is not None + or self.denormalize_name(remote_owner) != schema + ): + rec["referred_schema"] = remote_owner + + if row[9] != "NO ACTION": + rec["options"]["ondelete"] = row[9] + + local_cols.append(local_column) + remote_cols.append(remote_column) + + return list(fkeys.values()) + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + + constraint_data = self._get_constraint_data( + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) + + unique_keys = filter(lambda x: x[1] == "U", constraint_data) + uniques_group = groupby(unique_keys, lambda x: x[0]) + + index_names = { + ix["name"] + for ix in self.get_indexes(connection, table_name, schema=schema) + } + return [ + { + "name": name, + "column_names": cols, + "duplicates_index": name if name in index_names else None, + } + for name, cols in [ + [ + self.normalize_name(i[0]), + [self.normalize_name(x[2]) for x in i[1]], + ] + for i in uniques_group + ] + ] + + @reflection.cache + def get_view_definition( + self, + connection, + view_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + info_cache = kw.get("info_cache") + (view_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + view_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + + params = {"view_name": view_name} + text = "SELECT text FROM all_views WHERE view_name=:view_name" + + if schema is not None: + text += " AND owner = :schema" + params["schema"] = schema + + rp = connection.execute(sql.text(text), params).scalar() + if rp: + if util.py2k: + rp = rp.decode(self.encoding) + return rp + else: + return None + + @reflection.cache + def get_check_constraints( + self, connection, table_name, schema=None, include_all=False, **kw + ): + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + + constraint_data = self._get_constraint_data( + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) + + check_constraints = filter(lambda x: x[1] == "C", constraint_data) + + return [ + {"name": self.normalize_name(cons[0]), "sqltext": cons[8]} + for cons in check_constraints + if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8]) + ] + + +class _OuterJoinColumn(sql.ClauseElement): + __visit_name__ = "outer_join_column" + + def __init__(self, column): + self.column = column diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py new file mode 100644 index 0000000..64029a4 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -0,0 +1,1424 @@ +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" +.. dialect:: oracle+cx_oracle + :name: cx-Oracle + :dbapi: cx_oracle + :connectstring: oracle+cx_oracle://user:pass@hostname:port[/dbname][?service_name=<service>[&key=value&key=value...]] + :url: https://oracle.github.io/python-cx_Oracle/ + +DSN vs. Hostname connections +----------------------------- + +cx_Oracle provides several methods of indicating the target database. The +dialect translates from a series of different URL forms. + +Hostname Connections with Easy Connect Syntax +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Given a hostname, port and service name of the target Oracle Database, for +example from Oracle's `Easy Connect syntax +<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#easy-connect-syntax-for-connection-strings>`_, +then connect in SQLAlchemy using the ``service_name`` query string parameter:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:port/?service_name=myservice&encoding=UTF-8&nencoding=UTF-8") + +The `full Easy Connect syntax +<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-B0437826-43C1-49EC-A94D-B650B6A4A6EE>`_ +is not supported. Instead, use a ``tnsnames.ora`` file and connect using a +DSN. + +Connections with tnsnames.ora or Oracle Cloud +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Alternatively, if no port, database name, or ``service_name`` is provided, the +dialect will use an Oracle DSN "connection string". This takes the "hostname" +portion of the URL as the data source name. For example, if the +``tnsnames.ora`` file contains a `Net Service Name +<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#net-service-names-for-connection-strings>`_ +of ``myalias`` as below:: + + myalias = + (DESCRIPTION = + (ADDRESS = (PROTOCOL = TCP)(HOST = mymachine.example.com)(PORT = 1521)) + (CONNECT_DATA = + (SERVER = DEDICATED) + (SERVICE_NAME = orclpdb1) + ) + ) + +The cx_Oracle dialect connects to this database service when ``myalias`` is the +hostname portion of the URL, without specifying a port, database name or +``service_name``:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@myalias/?encoding=UTF-8&nencoding=UTF-8") + +Users of Oracle Cloud should use this syntax and also configure the cloud +wallet as shown in cx_Oracle documentation `Connecting to Autononmous Databases +<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#connecting-to-autononmous-databases>`_. + +SID Connections +^^^^^^^^^^^^^^^ + +To use Oracle's obsolete SID connection syntax, the SID can be passed in a +"database name" portion of the URL as below:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:1521/dbname?encoding=UTF-8&nencoding=UTF-8") + +Above, the DSN passed to cx_Oracle is created by ``cx_Oracle.makedsn()`` as +follows:: + + >>> import cx_Oracle + >>> cx_Oracle.makedsn("hostname", 1521, sid="dbname") + '(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=hostname)(PORT=1521))(CONNECT_DATA=(SID=dbname)))' + +Passing cx_Oracle connect arguments +----------------------------------- + +Additional connection arguments can usually be passed via the URL +query string; particular symbols like ``cx_Oracle.SYSDBA`` are intercepted +and converted to the correct symbol:: + + e = create_engine( + "oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true") + +.. versionchanged:: 1.3 the cx_oracle dialect now accepts all argument names + within the URL string itself, to be passed to the cx_Oracle DBAPI. As + was the case earlier but not correctly documented, the + :paramref:`_sa.create_engine.connect_args` parameter also accepts all + cx_Oracle DBAPI connect arguments. + +To pass arguments directly to ``.connect()`` without using the query +string, use the :paramref:`_sa.create_engine.connect_args` dictionary. +Any cx_Oracle parameter value and/or constant may be passed, such as:: + + import cx_Oracle + e = create_engine( + "oracle+cx_oracle://user:pass@dsn", + connect_args={ + "encoding": "UTF-8", + "nencoding": "UTF-8", + "mode": cx_Oracle.SYSDBA, + "events": True + } + ) + +Note that the default value for ``encoding`` and ``nencoding`` was changed to +"UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when using that +version, or later. + +Options consumed by the SQLAlchemy cx_Oracle dialect outside of the driver +-------------------------------------------------------------------------- + +There are also options that are consumed by the SQLAlchemy cx_oracle dialect +itself. These options are always passed directly to :func:`_sa.create_engine` +, such as:: + + e = create_engine( + "oracle+cx_oracle://user:pass@dsn", coerce_to_unicode=False) + +The parameters accepted by the cx_oracle dialect are as follows: + +* ``arraysize`` - set the cx_oracle.arraysize value on cursors, defaulted + to 50. This setting is significant with cx_Oracle as the contents of LOB + objects are only readable within a "live" row (e.g. within a batch of + 50 rows). + +* ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`. + +* ``coerce_to_unicode`` - see :ref:`cx_oracle_unicode` for detail. + +* ``coerce_to_decimal`` - see :ref:`cx_oracle_numeric` for detail. + +* ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail. + +.. _cx_oracle_sessionpool: + +Using cx_Oracle SessionPool +--------------------------- + +The cx_Oracle library provides its own connection pool implementation that may +be used in place of SQLAlchemy's pooling functionality. This can be achieved +by using the :paramref:`_sa.create_engine.creator` parameter to provide a +function that returns a new connection, along with setting +:paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable +SQLAlchemy's pooling:: + + import cx_Oracle + from sqlalchemy import create_engine + from sqlalchemy.pool import NullPool + + pool = cx_Oracle.SessionPool( + user="scott", password="tiger", dsn="orclpdb", + min=2, max=5, increment=1, threaded=True, + encoding="UTF-8", nencoding="UTF-8" + ) + + engine = create_engine("oracle://", creator=pool.acquire, poolclass=NullPool) + +The above engine may then be used normally where cx_Oracle's pool handles +connection pooling:: + + with engine.connect() as conn: + print(conn.scalar("select 1 FROM dual")) + + +As well as providing a scalable solution for multi-user applications, the +cx_Oracle session pool supports some Oracle features such as DRCP and +`Application Continuity +<https://cx-oracle.readthedocs.io/en/latest/user_guide/ha.html#application-continuity-ac>`_. + +Using Oracle Database Resident Connection Pooling (DRCP) +-------------------------------------------------------- + +When using Oracle's `DRCP +<https://www.oracle.com/pls/topic/lookup?ctx=dblatest&id=GUID-015CA8C1-2386-4626-855D-CC546DDC1086>`_, +the best practice is to pass a connection class and "purity" when acquiring a +connection from the SessionPool. Refer to the `cx_Oracle DRCP documentation +<https://cx-oracle.readthedocs.io/en/latest/user_guide/connection_handling.html#database-resident-connection-pooling-drcp>`_. + +This can be achieved by wrapping ``pool.acquire()``:: + + import cx_Oracle + from sqlalchemy import create_engine + from sqlalchemy.pool import NullPool + + pool = cx_Oracle.SessionPool( + user="scott", password="tiger", dsn="orclpdb", + min=2, max=5, increment=1, threaded=True, + encoding="UTF-8", nencoding="UTF-8" + ) + + def creator(): + return pool.acquire(cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF) + + engine = create_engine("oracle://", creator=creator, poolclass=NullPool) + +The above engine may then be used normally where cx_Oracle handles session +pooling and Oracle Database additionally uses DRCP:: + + with engine.connect() as conn: + print(conn.scalar("select 1 FROM dual")) + +.. _cx_oracle_unicode: + +Unicode +------- + +As is the case for all DBAPIs under Python 3, all strings are inherently +Unicode strings. Under Python 2, cx_Oracle also supports Python Unicode +objects directly. In all cases however, the driver requires an explicit +encoding configuration. + +Ensuring the Correct Client Encoding +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The long accepted standard for establishing client encoding for nearly all +Oracle related software is via the `NLS_LANG <https://www.oracle.com/database/technologies/faq-nls-lang.html>`_ +environment variable. cx_Oracle like most other Oracle drivers will use +this environment variable as the source of its encoding configuration. The +format of this variable is idiosyncratic; a typical value would be +``AMERICAN_AMERICA.AL32UTF8``. + +The cx_Oracle driver also supports a programmatic alternative which is to +pass the ``encoding`` and ``nencoding`` parameters directly to its +``.connect()`` function. These can be present in the URL as follows:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@orclpdb/?encoding=UTF-8&nencoding=UTF-8") + +For the meaning of the ``encoding`` and ``nencoding`` parameters, please +consult +`Characters Sets and National Language Support (NLS) <https://cx-oracle.readthedocs.io/en/latest/user_guide/globalization.html#globalization>`_. + +.. seealso:: + + `Characters Sets and National Language Support (NLS) <https://cx-oracle.readthedocs.io/en/latest/user_guide/globalization.html#globalization>`_ + - in the cx_Oracle documentation. + + +Unicode-specific Column datatypes +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The Core expression language handles unicode data by use of the :class:`.Unicode` +and :class:`.UnicodeText` +datatypes. These types correspond to the VARCHAR2 and CLOB Oracle datatypes by +default. When using these datatypes with Unicode data, it is expected that +the Oracle database is configured with a Unicode-aware character set, as well +as that the ``NLS_LANG`` environment variable is set appropriately, so that +the VARCHAR2 and CLOB datatypes can accommodate the data. + +In the case that the Oracle database is not configured with a Unicode character +set, the two options are to use the :class:`_types.NCHAR` and +:class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag +``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, +which will cause the +SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` / +:class:`.UnicodeText` datatypes instead of VARCHAR/CLOB. + +.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText` + datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle datatypes + unless the ``use_nchar_for_unicode=True`` is passed to the dialect + when :func:`_sa.create_engine` is called. + +Unicode Coercion of result rows under Python 2 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When result sets are fetched that include strings, under Python 3 the cx_Oracle +DBAPI returns all strings as Python Unicode objects, since Python 3 only has a +Unicode string type. This occurs for data fetched from datatypes such as +VARCHAR2, CHAR, CLOB, NCHAR, NCLOB, etc. In order to provide cross- +compatibility under Python 2, the SQLAlchemy cx_Oracle dialect will add +Unicode-conversion to string data under Python 2 as well. Historically, this +made use of converters that were supplied by cx_Oracle but were found to be +non-performant; SQLAlchemy's own converters are used for the string to Unicode +conversion under Python 2. To disable the Python 2 Unicode conversion for +VARCHAR2, CHAR, and CLOB, the flag ``coerce_to_unicode=False`` can be passed to +:func:`_sa.create_engine`. + +.. versionchanged:: 1.3 Unicode conversion is applied to all string values + by default under python 2. The ``coerce_to_unicode`` now defaults to True + and can be set to False to disable the Unicode coercion of strings that are + delivered as VARCHAR2/CHAR/CLOB data. + +.. _cx_oracle_unicode_encoding_errors: + +Encoding Errors +^^^^^^^^^^^^^^^ + +For the unusual case that data in the Oracle database is present with a broken +encoding, the dialect accepts a parameter ``encoding_errors`` which will be +passed to Unicode decoding functions in order to affect how decoding errors are +handled. The value is ultimately consumed by the Python `decode +<https://docs.python.org/3/library/stdtypes.html#bytes.decode>`_ function, and +is passed both via cx_Oracle's ``encodingErrors`` parameter consumed by +``Cursor.var()``, as well as SQLAlchemy's own decoding function, as the +cx_Oracle dialect makes use of both under different circumstances. + +.. versionadded:: 1.3.11 + + +.. _cx_oracle_setinputsizes: + +Fine grained control over cx_Oracle data binding performance with setinputsizes +------------------------------------------------------------------------------- + +The cx_Oracle DBAPI has a deep and fundamental reliance upon the usage of the +DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the +datatypes that are bound to a SQL statement for Python values being passed as +parameters. While virtually no other DBAPI assigns any use to the +``setinputsizes()`` call, the cx_Oracle DBAPI relies upon it heavily in its +interactions with the Oracle client interface, and in some scenarios it is not +possible for SQLAlchemy to know exactly how data should be bound, as some +settings can cause profoundly different performance characteristics, while +altering the type coercion behavior at the same time. + +Users of the cx_Oracle dialect are **strongly encouraged** to read through +cx_Oracle's list of built-in datatype symbols at +https://cx-oracle.readthedocs.io/en/latest/api_manual/module.html#database-types. +Note that in some cases, significant performance degradation can occur when +using these types vs. not, in particular when specifying ``cx_Oracle.CLOB``. + +On the SQLAlchemy side, the :meth:`.DialectEvents.do_setinputsizes` event can +be used both for runtime visibility (e.g. logging) of the setinputsizes step as +well as to fully control how ``setinputsizes()`` is used on a per-statement +basis. + +.. versionadded:: 1.2.9 Added :meth:`.DialectEvents.setinputsizes` + + +Example 1 - logging all setinputsizes calls +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The following example illustrates how to log the intermediary values from a +SQLAlchemy perspective before they are converted to the raw ``setinputsizes()`` +parameter dictionary. The keys of the dictionary are :class:`.BindParameter` +objects which have a ``.key`` and a ``.type`` attribute:: + + from sqlalchemy import create_engine, event + + engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe") + + @event.listens_for(engine, "do_setinputsizes") + def _log_setinputsizes(inputsizes, cursor, statement, parameters, context): + for bindparam, dbapitype in inputsizes.items(): + log.info( + "Bound parameter name: %s SQLAlchemy type: %r " + "DBAPI object: %s", + bindparam.key, bindparam.type, dbapitype) + +Example 2 - remove all bindings to CLOB +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The ``CLOB`` datatype in cx_Oracle incurs a significant performance overhead, +however is set by default for the ``Text`` type within the SQLAlchemy 1.2 +series. This setting can be modified as follows:: + + from sqlalchemy import create_engine, event + from cx_Oracle import CLOB + + engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe") + + @event.listens_for(engine, "do_setinputsizes") + def _remove_clob(inputsizes, cursor, statement, parameters, context): + for bindparam, dbapitype in list(inputsizes.items()): + if dbapitype is CLOB: + del inputsizes[bindparam] + +.. _cx_oracle_returning: + +RETURNING Support +----------------- + +The cx_Oracle dialect implements RETURNING using OUT parameters. +The dialect supports RETURNING fully, however cx_Oracle 6 is recommended +for complete support. + +.. _cx_oracle_lob: + +LOB Objects +----------- + +cx_oracle returns oracle LOBs using the cx_oracle.LOB object. SQLAlchemy +converts these to strings so that the interface of the Binary type is +consistent with that of other backends, which takes place within a cx_Oracle +outputtypehandler. + +cx_Oracle prior to version 6 would require that LOB objects be read before +a new batch of rows would be read, as determined by the ``cursor.arraysize``. +As of the 6 series, this limitation has been lifted. Nevertheless, because +SQLAlchemy pre-reads these LOBs up front, this issue is avoided in any case. + +To disable the auto "read()" feature of the dialect, the flag +``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`. Under +the cx_Oracle 5 series, having this flag turned off means there is the chance +of reading from a stale LOB object if not read as it is fetched. With +cx_Oracle 6, this issue is resolved. + +.. versionchanged:: 1.2 the LOB handling system has been greatly simplified + internally to make use of outputtypehandlers, and no longer makes use + of alternate "buffered" result set objects. + +Two Phase Transactions Not Supported +------------------------------------- + +Two phase transactions are **not supported** under cx_Oracle due to poor +driver support. As of cx_Oracle 6.0b1, the interface for +two phase transactions has been changed to be more of a direct pass-through +to the underlying OCI layer with less automation. The additional logic +to support this system is not implemented in SQLAlchemy. + +.. _cx_oracle_numeric: + +Precision Numerics +------------------ + +SQLAlchemy's numeric types can handle receiving and returning values as Python +``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a +subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in +use, the :paramref:`.Numeric.asdecimal` flag determines if values should be +coerced to ``Decimal`` upon return, or returned as float objects. To make +matters more complicated under Oracle, Oracle's ``NUMBER`` type can also +represent integer values if the "scale" is zero, so the Oracle-specific +:class:`_oracle.NUMBER` type takes this into account as well. + +The cx_Oracle dialect makes extensive use of connection- and cursor-level +"outputtypehandler" callables in order to coerce numeric values as requested. +These callables are specific to the specific flavor of :class:`.Numeric` in +use, as well as if no SQLAlchemy typing objects are present. There are +observed scenarios where Oracle may sends incomplete or ambiguous information +about the numeric types being returned, such as a query where the numeric types +are buried under multiple levels of subquery. The type handlers do their best +to make the right decision in all cases, deferring to the underlying cx_Oracle +DBAPI for all those cases where the driver can make the best decision. + +When no typing objects are present, as when executing plain SQL strings, a +default "outputtypehandler" is present which will generally return numeric +values which specify precision and scale as Python ``Decimal`` objects. To +disable this coercion to decimal for performance reasons, pass the flag +``coerce_to_decimal=False`` to :func:`_sa.create_engine`:: + + engine = create_engine("oracle+cx_oracle://dsn", coerce_to_decimal=False) + +The ``coerce_to_decimal`` flag only impacts the results of plain string +SQL statements that are not otherwise associated with a :class:`.Numeric` +SQLAlchemy type (or a subclass of such). + +.. versionchanged:: 1.2 The numeric handling system for cx_Oracle has been + reworked to take advantage of newer cx_Oracle features as well + as better integration of outputtypehandlers. + +""" # noqa + +from __future__ import absolute_import + +import decimal +import random +import re + +from . import base as oracle +from .base import OracleCompiler +from .base import OracleDialect +from .base import OracleExecutionContext +from ... import exc +from ... import processors +from ... import types as sqltypes +from ... import util +from ...engine import cursor as _cursor +from ...util import compat + + +class _OracleInteger(sqltypes.Integer): + def get_dbapi_type(self, dbapi): + # see https://github.com/oracle/python-cx_Oracle/issues/ + # 208#issuecomment-409715955 + return int + + def _cx_oracle_var(self, dialect, cursor): + cx_Oracle = dialect.dbapi + return cursor.var( + cx_Oracle.STRING, 255, arraysize=cursor.arraysize, outconverter=int + ) + + def _cx_oracle_outputtypehandler(self, dialect): + def handler(cursor, name, default_type, size, precision, scale): + return self._cx_oracle_var(dialect, cursor) + + return handler + + +class _OracleNumeric(sqltypes.Numeric): + is_number = False + + def bind_processor(self, dialect): + if self.scale == 0: + return None + elif self.asdecimal: + processor = processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + + def process(value): + if isinstance(value, (int, float)): + return processor(value) + elif value is not None and value.is_infinite(): + return float(value) + else: + return value + + return process + else: + return processors.to_float + + def result_processor(self, dialect, coltype): + return None + + def _cx_oracle_outputtypehandler(self, dialect): + cx_Oracle = dialect.dbapi + + is_cx_oracle_6 = dialect._is_cx_oracle_6 + + def handler(cursor, name, default_type, size, precision, scale): + outconverter = None + + if precision: + if self.asdecimal: + if default_type == cx_Oracle.NATIVE_FLOAT: + # receiving float and doing Decimal after the fact + # allows for float("inf") to be handled + type_ = default_type + outconverter = decimal.Decimal + elif is_cx_oracle_6: + type_ = decimal.Decimal + else: + type_ = cx_Oracle.STRING + outconverter = dialect._to_decimal + else: + if self.is_number and scale == 0: + # integer. cx_Oracle is observed to handle the widest + # variety of ints when no directives are passed, + # from 5.2 to 7.0. See [ticket:4457] + return None + else: + type_ = cx_Oracle.NATIVE_FLOAT + + else: + if self.asdecimal: + if default_type == cx_Oracle.NATIVE_FLOAT: + type_ = default_type + outconverter = decimal.Decimal + elif is_cx_oracle_6: + type_ = decimal.Decimal + else: + type_ = cx_Oracle.STRING + outconverter = dialect._to_decimal + else: + if self.is_number and scale == 0: + # integer. cx_Oracle is observed to handle the widest + # variety of ints when no directives are passed, + # from 5.2 to 7.0. See [ticket:4457] + return None + else: + type_ = cx_Oracle.NATIVE_FLOAT + + return cursor.var( + type_, + 255, + arraysize=cursor.arraysize, + outconverter=outconverter, + ) + + return handler + + +class _OracleBinaryFloat(_OracleNumeric): + def get_dbapi_type(self, dbapi): + return dbapi.NATIVE_FLOAT + + +class _OracleBINARY_FLOAT(_OracleBinaryFloat, oracle.BINARY_FLOAT): + pass + + +class _OracleBINARY_DOUBLE(_OracleBinaryFloat, oracle.BINARY_DOUBLE): + pass + + +class _OracleNUMBER(_OracleNumeric): + is_number = True + + +class _OracleDate(sqltypes.Date): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + def process(value): + if value is not None: + return value.date() + else: + return value + + return process + + +# TODO: the names used across CHAR / VARCHAR / NCHAR / NVARCHAR +# here are inconsistent and not very good +class _OracleChar(sqltypes.CHAR): + def get_dbapi_type(self, dbapi): + return dbapi.FIXED_CHAR + + +class _OracleNChar(sqltypes.NCHAR): + def get_dbapi_type(self, dbapi): + return dbapi.FIXED_NCHAR + + +class _OracleUnicodeStringNCHAR(oracle.NVARCHAR2): + def get_dbapi_type(self, dbapi): + return dbapi.NCHAR + + +class _OracleUnicodeStringCHAR(sqltypes.Unicode): + def get_dbapi_type(self, dbapi): + return dbapi.LONG_STRING + + +class _OracleUnicodeTextNCLOB(oracle.NCLOB): + def get_dbapi_type(self, dbapi): + return dbapi.NCLOB + + +class _OracleUnicodeTextCLOB(sqltypes.UnicodeText): + def get_dbapi_type(self, dbapi): + return dbapi.CLOB + + +class _OracleText(sqltypes.Text): + def get_dbapi_type(self, dbapi): + return dbapi.CLOB + + +class _OracleLong(oracle.LONG): + def get_dbapi_type(self, dbapi): + return dbapi.LONG_STRING + + +class _OracleString(sqltypes.String): + pass + + +class _OracleEnum(sqltypes.Enum): + def bind_processor(self, dialect): + enum_proc = sqltypes.Enum.bind_processor(self, dialect) + + def process(value): + raw_str = enum_proc(value) + return raw_str + + return process + + +class _OracleBinary(sqltypes.LargeBinary): + def get_dbapi_type(self, dbapi): + return dbapi.BLOB + + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if not dialect.auto_convert_lobs: + return None + else: + return super(_OracleBinary, self).result_processor( + dialect, coltype + ) + + +class _OracleInterval(oracle.INTERVAL): + def get_dbapi_type(self, dbapi): + return dbapi.INTERVAL + + +class _OracleRaw(oracle.RAW): + pass + + +class _OracleRowid(oracle.ROWID): + def get_dbapi_type(self, dbapi): + return dbapi.ROWID + + +class OracleCompiler_cx_oracle(OracleCompiler): + _oracle_cx_sql_compiler = True + + def bindparam_string(self, name, **kw): + quote = getattr(name, "quote", None) + if ( + quote is True + or quote is not False + and self.preparer._bindparam_requires_quotes(name) + and not kw.get("post_compile", False) + ): + # interesting to note about expanding parameters - since the + # new parameters take the form <paramname>_<int>, at least if + # they are originally formed from reserved words, they no longer + # need quoting :). names that include illegal characters + # won't work however. + quoted_name = '"%s"' % name + kw["escaped_from"] = name + name = quoted_name + + return OracleCompiler.bindparam_string(self, name, **kw) + + +class OracleExecutionContext_cx_oracle(OracleExecutionContext): + out_parameters = None + + def _generate_out_parameter_vars(self): + # check for has_out_parameters or RETURNING, create cx_Oracle.var + # objects if so + if self.compiled.returning or self.compiled.has_out_parameters: + quoted_bind_names = self.compiled.escaped_bind_names + for bindparam in self.compiled.binds.values(): + if bindparam.isoutparam: + name = self.compiled.bind_names[bindparam] + type_impl = bindparam.type.dialect_impl(self.dialect) + + if hasattr(type_impl, "_cx_oracle_var"): + self.out_parameters[name] = type_impl._cx_oracle_var( + self.dialect, self.cursor + ) + else: + dbtype = type_impl.get_dbapi_type(self.dialect.dbapi) + + cx_Oracle = self.dialect.dbapi + + if dbtype is None: + raise exc.InvalidRequestError( + "Cannot create out parameter for " + "parameter " + "%r - its type %r is not supported by" + " cx_oracle" % (bindparam.key, bindparam.type) + ) + + if compat.py2k and dbtype in ( + cx_Oracle.CLOB, + cx_Oracle.NCLOB, + ): + outconverter = ( + processors.to_unicode_processor_factory( + self.dialect.encoding, + errors=self.dialect.encoding_errors, + ) + ) + self.out_parameters[name] = self.cursor.var( + dbtype, + outconverter=lambda value: outconverter( + value.read() + ), + ) + + elif dbtype in ( + cx_Oracle.BLOB, + cx_Oracle.CLOB, + cx_Oracle.NCLOB, + ): + self.out_parameters[name] = self.cursor.var( + dbtype, outconverter=lambda value: value.read() + ) + elif compat.py2k and isinstance( + type_impl, sqltypes.Unicode + ): + outconverter = ( + processors.to_unicode_processor_factory( + self.dialect.encoding, + errors=self.dialect.encoding_errors, + ) + ) + self.out_parameters[name] = self.cursor.var( + dbtype, outconverter=outconverter + ) + else: + self.out_parameters[name] = self.cursor.var(dbtype) + self.parameters[0][ + quoted_bind_names.get(name, name) + ] = self.out_parameters[name] + + def _generate_cursor_outputtype_handler(self): + output_handlers = {} + + for (keyname, name, objects, type_) in self.compiled._result_columns: + handler = type_._cached_custom_processor( + self.dialect, + "cx_oracle_outputtypehandler", + self._get_cx_oracle_type_handler, + ) + + if handler: + denormalized_name = self.dialect.denormalize_name(keyname) + output_handlers[denormalized_name] = handler + + if output_handlers: + default_handler = self._dbapi_connection.outputtypehandler + + def output_type_handler( + cursor, name, default_type, size, precision, scale + ): + if name in output_handlers: + return output_handlers[name]( + cursor, name, default_type, size, precision, scale + ) + else: + return default_handler( + cursor, name, default_type, size, precision, scale + ) + + self.cursor.outputtypehandler = output_type_handler + + def _get_cx_oracle_type_handler(self, impl): + if hasattr(impl, "_cx_oracle_outputtypehandler"): + return impl._cx_oracle_outputtypehandler(self.dialect) + else: + return None + + def pre_exec(self): + if not getattr(self.compiled, "_oracle_cx_sql_compiler", False): + return + + self.out_parameters = {} + + self._generate_out_parameter_vars() + + self._generate_cursor_outputtype_handler() + + self.include_set_input_sizes = self.dialect._include_setinputsizes + + def post_exec(self): + if self.compiled and self.out_parameters and self.compiled.returning: + # create a fake cursor result from the out parameters. unlike + # get_out_parameter_values(), the result-row handlers here will be + # applied at the Result level + returning_params = [ + self.dialect._returningval(self.out_parameters["ret_%d" % i]) + for i in range(len(self.out_parameters)) + ] + + fetch_strategy = _cursor.FullyBufferedCursorFetchStrategy( + self.cursor, + [ + (getattr(col, "name", col._anon_name_label), None) + for col in self.compiled.returning + ], + initial_buffer=[tuple(returning_params)], + ) + + self.cursor_fetch_strategy = fetch_strategy + + def create_cursor(self): + c = self._dbapi_connection.cursor() + if self.dialect.arraysize: + c.arraysize = self.dialect.arraysize + + return c + + def get_out_parameter_values(self, out_param_names): + # this method should not be called when the compiler has + # RETURNING as we've turned the has_out_parameters flag set to + # False. + assert not self.compiled.returning + + return [ + self.dialect._paramval(self.out_parameters[name]) + for name in out_param_names + ] + + +class OracleDialect_cx_oracle(OracleDialect): + supports_statement_cache = True + execution_ctx_cls = OracleExecutionContext_cx_oracle + statement_compiler = OracleCompiler_cx_oracle + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + supports_unicode_statements = True + supports_unicode_binds = True + + use_setinputsizes = True + + driver = "cx_oracle" + + colspecs = { + sqltypes.Numeric: _OracleNumeric, + sqltypes.Float: _OracleNumeric, + oracle.BINARY_FLOAT: _OracleBINARY_FLOAT, + oracle.BINARY_DOUBLE: _OracleBINARY_DOUBLE, + sqltypes.Integer: _OracleInteger, + oracle.NUMBER: _OracleNUMBER, + sqltypes.Date: _OracleDate, + sqltypes.LargeBinary: _OracleBinary, + sqltypes.Boolean: oracle._OracleBoolean, + sqltypes.Interval: _OracleInterval, + oracle.INTERVAL: _OracleInterval, + sqltypes.Text: _OracleText, + sqltypes.String: _OracleString, + sqltypes.UnicodeText: _OracleUnicodeTextCLOB, + sqltypes.CHAR: _OracleChar, + sqltypes.NCHAR: _OracleNChar, + sqltypes.Enum: _OracleEnum, + oracle.LONG: _OracleLong, + oracle.RAW: _OracleRaw, + sqltypes.Unicode: _OracleUnicodeStringCHAR, + sqltypes.NVARCHAR: _OracleUnicodeStringNCHAR, + oracle.NCLOB: _OracleUnicodeTextNCLOB, + oracle.ROWID: _OracleRowid, + } + + execute_sequence_format = list + + _cx_oracle_threaded = None + + @util.deprecated_params( + threaded=( + "1.3", + "The 'threaded' parameter to the cx_oracle dialect " + "is deprecated as a dialect-level argument, and will be removed " + "in a future release. As of version 1.3, it defaults to False " + "rather than True. The 'threaded' option can be passed to " + "cx_Oracle directly in the URL query string passed to " + ":func:`_sa.create_engine`.", + ) + ) + def __init__( + self, + auto_convert_lobs=True, + coerce_to_unicode=True, + coerce_to_decimal=True, + arraysize=50, + encoding_errors=None, + threaded=None, + **kwargs + ): + + OracleDialect.__init__(self, **kwargs) + self.arraysize = arraysize + self.encoding_errors = encoding_errors + if threaded is not None: + self._cx_oracle_threaded = threaded + self.auto_convert_lobs = auto_convert_lobs + self.coerce_to_unicode = coerce_to_unicode + self.coerce_to_decimal = coerce_to_decimal + if self._use_nchar_for_unicode: + self.colspecs = self.colspecs.copy() + self.colspecs[sqltypes.Unicode] = _OracleUnicodeStringNCHAR + self.colspecs[sqltypes.UnicodeText] = _OracleUnicodeTextNCLOB + + cx_Oracle = self.dbapi + + if cx_Oracle is None: + self._include_setinputsizes = {} + self.cx_oracle_ver = (0, 0, 0) + else: + self.cx_oracle_ver = self._parse_cx_oracle_ver(cx_Oracle.version) + if self.cx_oracle_ver < (5, 2) and self.cx_oracle_ver > (0, 0, 0): + raise exc.InvalidRequestError( + "cx_Oracle version 5.2 and above are supported" + ) + + self._include_setinputsizes = { + cx_Oracle.DATETIME, + cx_Oracle.NCLOB, + cx_Oracle.CLOB, + cx_Oracle.LOB, + cx_Oracle.NCHAR, + cx_Oracle.FIXED_NCHAR, + cx_Oracle.BLOB, + cx_Oracle.FIXED_CHAR, + cx_Oracle.TIMESTAMP, + _OracleInteger, + _OracleBINARY_FLOAT, + _OracleBINARY_DOUBLE, + } + + self._paramval = lambda value: value.getvalue() + + # https://github.com/oracle/python-cx_Oracle/issues/176#issuecomment-386821291 + # https://github.com/oracle/python-cx_Oracle/issues/224 + self._values_are_lists = self.cx_oracle_ver >= (6, 3) + if self._values_are_lists: + cx_Oracle.__future__.dml_ret_array_val = True + + def _returningval(value): + try: + return value.values[0][0] + except IndexError: + return None + + self._returningval = _returningval + else: + self._returningval = self._paramval + + self._is_cx_oracle_6 = self.cx_oracle_ver >= (6,) + + @property + def _cursor_var_unicode_kwargs(self): + if self.encoding_errors: + if self.cx_oracle_ver >= (6, 4): + return {"encodingErrors": self.encoding_errors} + else: + util.warn( + "cx_oracle version %r does not support encodingErrors" + % (self.cx_oracle_ver,) + ) + + return {} + + def _parse_cx_oracle_ver(self, version): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) + if m: + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) + else: + return (0, 0, 0) + + @classmethod + def dbapi(cls): + import cx_Oracle + + return cx_Oracle + + def initialize(self, connection): + super(OracleDialect_cx_oracle, self).initialize(connection) + if self._is_oracle_8: + self.supports_unicode_binds = False + + self._detect_decimal_char(connection) + + def get_isolation_level(self, connection): + # sources: + + # general idea of transaction id, have to start one, etc. + # https://stackoverflow.com/questions/10711204/how-to-check-isoloation-level + + # how to decode xid cols from v$transaction to match + # https://asktom.oracle.com/pls/apex/f?p=100:11:0::::P11_QUESTION_ID:9532779900346079444 + + # Oracle tuple comparison without using IN: + # https://www.sql-workbench.eu/comparison/tuple_comparison.html + + with connection.cursor() as cursor: + # this is the only way to ensure a transaction is started without + # actually running DML. There's no way to see the configured + # isolation level without getting it from v$transaction which + # means transaction has to be started. + outval = cursor.var(str) + cursor.execute( + """ + begin + :trans_id := dbms_transaction.local_transaction_id( TRUE ); + end; + """, + {"trans_id": outval}, + ) + trans_id = outval.getvalue() + xidusn, xidslot, xidsqn = trans_id.split(".", 2) + + cursor.execute( + "SELECT CASE BITAND(t.flag, POWER(2, 28)) " + "WHEN 0 THEN 'READ COMMITTED' " + "ELSE 'SERIALIZABLE' END AS isolation_level " + "FROM v$transaction t WHERE " + "(t.xidusn, t.xidslot, t.xidsqn) = " + "((:xidusn, :xidslot, :xidsqn))", + {"xidusn": xidusn, "xidslot": xidslot, "xidsqn": xidsqn}, + ) + row = cursor.fetchone() + if row is None: + raise exc.InvalidRequestError( + "could not retrieve isolation level" + ) + result = row[0] + + return result + + def set_isolation_level(self, connection, level): + if hasattr(connection, "dbapi_connection"): + dbapi_connection = connection.dbapi_connection + else: + dbapi_connection = connection + if level == "AUTOCOMMIT": + dbapi_connection.autocommit = True + else: + dbapi_connection.autocommit = False + connection.rollback() + with connection.cursor() as cursor: + cursor.execute("ALTER SESSION SET ISOLATION_LEVEL=%s" % level) + + def _detect_decimal_char(self, connection): + # we have the option to change this setting upon connect, + # or just look at what it is upon connect and convert. + # to minimize the chance of interference with changes to + # NLS_TERRITORY or formatting behavior of the DB, we opt + # to just look at it + + self._decimal_char = connection.exec_driver_sql( + "select value from nls_session_parameters " + "where parameter = 'NLS_NUMERIC_CHARACTERS'" + ).scalar()[0] + if self._decimal_char != ".": + _detect_decimal = self._detect_decimal + _to_decimal = self._to_decimal + + self._detect_decimal = lambda value: _detect_decimal( + value.replace(self._decimal_char, ".") + ) + self._to_decimal = lambda value: _to_decimal( + value.replace(self._decimal_char, ".") + ) + + def _detect_decimal(self, value): + if "." in value: + return self._to_decimal(value) + else: + return int(value) + + _to_decimal = decimal.Decimal + + def _generate_connection_outputtype_handler(self): + """establish the default outputtypehandler established at the + connection level. + + """ + + dialect = self + cx_Oracle = dialect.dbapi + + number_handler = _OracleNUMBER( + asdecimal=True + )._cx_oracle_outputtypehandler(dialect) + float_handler = _OracleNUMBER( + asdecimal=False + )._cx_oracle_outputtypehandler(dialect) + + def output_type_handler( + cursor, name, default_type, size, precision, scale + ): + + if ( + default_type == cx_Oracle.NUMBER + and default_type is not cx_Oracle.NATIVE_FLOAT + ): + if not dialect.coerce_to_decimal: + return None + elif precision == 0 and scale in (0, -127): + # ambiguous type, this occurs when selecting + # numbers from deep subqueries + return cursor.var( + cx_Oracle.STRING, + 255, + outconverter=dialect._detect_decimal, + arraysize=cursor.arraysize, + ) + elif precision and scale > 0: + return number_handler( + cursor, name, default_type, size, precision, scale + ) + else: + return float_handler( + cursor, name, default_type, size, precision, scale + ) + + # allow all strings to come back natively as Unicode + elif ( + dialect.coerce_to_unicode + and default_type + in ( + cx_Oracle.STRING, + cx_Oracle.FIXED_CHAR, + ) + and default_type is not cx_Oracle.CLOB + and default_type is not cx_Oracle.NCLOB + ): + if compat.py2k: + outconverter = processors.to_unicode_processor_factory( + dialect.encoding, errors=dialect.encoding_errors + ) + return cursor.var( + cx_Oracle.STRING, + size, + cursor.arraysize, + outconverter=outconverter, + ) + else: + return cursor.var( + util.text_type, + size, + cursor.arraysize, + **dialect._cursor_var_unicode_kwargs + ) + + elif dialect.auto_convert_lobs and default_type in ( + cx_Oracle.CLOB, + cx_Oracle.NCLOB, + ): + if compat.py2k: + outconverter = processors.to_unicode_processor_factory( + dialect.encoding, errors=dialect.encoding_errors + ) + return cursor.var( + cx_Oracle.LONG_STRING, + size, + cursor.arraysize, + outconverter=outconverter, + ) + else: + return cursor.var( + cx_Oracle.LONG_STRING, + size, + cursor.arraysize, + **dialect._cursor_var_unicode_kwargs + ) + + elif dialect.auto_convert_lobs and default_type in ( + cx_Oracle.BLOB, + ): + return cursor.var( + cx_Oracle.LONG_BINARY, + size, + cursor.arraysize, + ) + + return output_type_handler + + def on_connect(self): + + output_type_handler = self._generate_connection_outputtype_handler() + + def on_connect(conn): + conn.outputtypehandler = output_type_handler + + return on_connect + + def create_connect_args(self, url): + opts = dict(url.query) + + for opt in ("use_ansi", "auto_convert_lobs"): + if opt in opts: + util.warn_deprecated( + "cx_oracle dialect option %r should only be passed to " + "create_engine directly, not within the URL string" % opt, + version="1.3", + ) + util.coerce_kw_type(opts, opt, bool) + setattr(self, opt, opts.pop(opt)) + + database = url.database + service_name = opts.pop("service_name", None) + if database or service_name: + # if we have a database, then we have a remote host + port = url.port + if port: + port = int(port) + else: + port = 1521 + + if database and service_name: + raise exc.InvalidRequestError( + '"service_name" option shouldn\'t ' + 'be used with a "database" part of the url' + ) + if database: + makedsn_kwargs = {"sid": database} + if service_name: + makedsn_kwargs = {"service_name": service_name} + + dsn = self.dbapi.makedsn(url.host, port, **makedsn_kwargs) + else: + # we have a local tnsname + dsn = url.host + + if dsn is not None: + opts["dsn"] = dsn + if url.password is not None: + opts["password"] = url.password + if url.username is not None: + opts["user"] = url.username + + if self._cx_oracle_threaded is not None: + opts.setdefault("threaded", self._cx_oracle_threaded) + + def convert_cx_oracle_constant(value): + if isinstance(value, util.string_types): + try: + int_val = int(value) + except ValueError: + value = value.upper() + return getattr(self.dbapi, value) + else: + return int_val + else: + return value + + util.coerce_kw_type(opts, "mode", convert_cx_oracle_constant) + util.coerce_kw_type(opts, "threaded", bool) + util.coerce_kw_type(opts, "events", bool) + util.coerce_kw_type(opts, "purity", convert_cx_oracle_constant) + return ([], opts) + + def _get_server_version_info(self, connection): + return tuple(int(x) for x in connection.connection.version.split(".")) + + def is_disconnect(self, e, connection, cursor): + (error,) = e.args + if isinstance( + e, (self.dbapi.InterfaceError, self.dbapi.DatabaseError) + ) and "not connected" in str(e): + return True + + if hasattr(error, "code") and error.code in { + 28, + 3114, + 3113, + 3135, + 1033, + 2396, + }: + # ORA-00028: your session has been killed + # ORA-03114: not connected to ORACLE + # ORA-03113: end-of-file on communication channel + # ORA-03135: connection lost contact + # ORA-01033: ORACLE initialization or shutdown in progress + # ORA-02396: exceeded maximum idle time, please connect again + # TODO: Others ? + return True + + if re.match(r"^(?:DPI-1010|DPI-1080|DPY-1001|DPY-4011)", str(e)): + # DPI-1010: not connected + # DPI-1080: connection was closed by ORA-3113 + # python-oracledb's DPY-1001: not connected to database + # python-oracledb's DPY-4011: the database or network closed the + # connection + # TODO: others? + return True + + return False + + def create_xid(self): + """create a two-phase transaction ID. + + this id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). its format is unspecified. + + """ + + id_ = random.randint(0, 2 ** 128) + return (0x1234, "%032x" % id_, "%032x" % 9) + + def do_executemany(self, cursor, statement, parameters, context=None): + if isinstance(parameters, tuple): + parameters = list(parameters) + cursor.executemany(statement, parameters) + + def do_begin_twophase(self, connection, xid): + connection.connection.begin(*xid) + connection.connection.info["cx_oracle_xid"] = xid + + def do_prepare_twophase(self, connection, xid): + result = connection.connection.prepare() + connection.info["cx_oracle_prepared"] = result + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + self.do_rollback(connection.connection) + # TODO: need to end XA state here + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + + if not is_prepared: + self.do_commit(connection.connection) + else: + if recover: + raise NotImplementedError( + "2pc recovery not implemented for cx_Oracle" + ) + oci_prepared = connection.info["cx_oracle_prepared"] + if oci_prepared: + self.do_commit(connection.connection) + # TODO: need to end XA state here + + def do_set_input_sizes(self, cursor, list_of_tuples, context): + if self.positional: + # not usually used, here to support if someone is modifying + # the dialect to use positional style + cursor.setinputsizes( + *[dbtype for key, dbtype, sqltype in list_of_tuples] + ) + else: + collection = ( + (key, dbtype) + for key, dbtype, sqltype in list_of_tuples + if dbtype + ) + + if not self.supports_unicode_binds: + # oracle 8 only + collection = ( + (self.dialect._encoder(key)[0], dbtype) + for key, dbtype in collection + ) + + cursor.setinputsizes(**{key: dbtype for key, dbtype in collection}) + + def do_recover_twophase(self, connection): + raise NotImplementedError( + "recover two phase query for cx_Oracle not implemented" + ) + + +dialect = OracleDialect_cx_oracle diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py new file mode 100644 index 0000000..74ad1f2 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/provision.py @@ -0,0 +1,160 @@ +from ... import create_engine +from ... import exc +from ...engine import url as sa_url +from ...testing.provision import configure_follower +from ...testing.provision import create_db +from ...testing.provision import drop_db +from ...testing.provision import follower_url_from_main +from ...testing.provision import log +from ...testing.provision import post_configure_engine +from ...testing.provision import run_reap_dbs +from ...testing.provision import set_default_schema_on_connection +from ...testing.provision import stop_test_class_outside_fixtures +from ...testing.provision import temp_table_keyword_args + + +@create_db.for_db("oracle") +def _oracle_create_db(cfg, eng, ident): + # NOTE: make sure you've run "ALTER DATABASE default tablespace users" or + # similar, so that the default tablespace is not "system"; reflection will + # fail otherwise + with eng.begin() as conn: + conn.exec_driver_sql("create user %s identified by xe" % ident) + conn.exec_driver_sql("create user %s_ts1 identified by xe" % ident) + conn.exec_driver_sql("create user %s_ts2 identified by xe" % ident) + conn.exec_driver_sql("grant dba to %s" % (ident,)) + conn.exec_driver_sql("grant unlimited tablespace to %s" % ident) + conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident) + conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident) + + +@configure_follower.for_db("oracle") +def _oracle_configure_follower(config, ident): + config.test_schema = "%s_ts1" % ident + config.test_schema_2 = "%s_ts2" % ident + + +def _ora_drop_ignore(conn, dbname): + try: + conn.exec_driver_sql("drop user %s cascade" % dbname) + log.info("Reaped db: %s", dbname) + return True + except exc.DatabaseError as err: + log.warning("couldn't drop db: %s", err) + return False + + +@drop_db.for_db("oracle") +def _oracle_drop_db(cfg, eng, ident): + with eng.begin() as conn: + # cx_Oracle seems to occasionally leak open connections when a large + # suite it run, even if we confirm we have zero references to + # connection objects. + # while there is a "kill session" command in Oracle, + # it unfortunately does not release the connection sufficiently. + _ora_drop_ignore(conn, ident) + _ora_drop_ignore(conn, "%s_ts1" % ident) + _ora_drop_ignore(conn, "%s_ts2" % ident) + + +@stop_test_class_outside_fixtures.for_db("oracle") +def stop_test_class_outside_fixtures(config, db, cls): + + try: + with db.begin() as conn: + # run magic command to get rid of identity sequences + # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501 + conn.exec_driver_sql("purge recyclebin") + except exc.DatabaseError as err: + log.warning("purge recyclebin command failed: %s", err) + + # clear statement cache on all connections that were used + # https://github.com/oracle/python-cx_Oracle/issues/519 + + for cx_oracle_conn in _all_conns: + try: + sc = cx_oracle_conn.stmtcachesize + except db.dialect.dbapi.InterfaceError: + # connection closed + pass + else: + cx_oracle_conn.stmtcachesize = 0 + cx_oracle_conn.stmtcachesize = sc + _all_conns.clear() + + +_all_conns = set() + + +@post_configure_engine.for_db("oracle") +def _oracle_post_configure_engine(url, engine, follower_ident): + from sqlalchemy import event + + @event.listens_for(engine, "checkout") + def checkout(dbapi_con, con_record, con_proxy): + _all_conns.add(dbapi_con) + + @event.listens_for(engine, "checkin") + def checkin(dbapi_connection, connection_record): + # work around cx_Oracle issue: + # https://github.com/oracle/python-cx_Oracle/issues/530 + # invalidate oracle connections that had 2pc set up + if "cx_oracle_xid" in connection_record.info: + connection_record.invalidate() + + +@run_reap_dbs.for_db("oracle") +def _reap_oracle_dbs(url, idents): + log.info("db reaper connecting to %r", url) + eng = create_engine(url) + with eng.begin() as conn: + + log.info("identifiers in file: %s", ", ".join(idents)) + + to_reap = conn.exec_driver_sql( + "select u.username from all_users u where username " + "like 'TEST_%' and not exists (select username " + "from v$session where username=u.username)" + ) + all_names = {username.lower() for (username,) in to_reap} + to_drop = set() + for name in all_names: + if name.endswith("_ts1") or name.endswith("_ts2"): + continue + elif name in idents: + to_drop.add(name) + if "%s_ts1" % name in all_names: + to_drop.add("%s_ts1" % name) + if "%s_ts2" % name in all_names: + to_drop.add("%s_ts2" % name) + + dropped = total = 0 + for total, username in enumerate(to_drop, 1): + if _ora_drop_ignore(conn, username): + dropped += 1 + log.info( + "Dropped %d out of %d stale databases detected", dropped, total + ) + + +@follower_url_from_main.for_db("oracle") +def _oracle_follower_url_from_main(url, ident): + url = sa_url.make_url(url) + return url.set(username=ident, password="xe") + + +@temp_table_keyword_args.for_db("oracle") +def _oracle_temp_table_keyword_args(cfg, eng): + return { + "prefixes": ["GLOBAL TEMPORARY"], + "oracle_on_commit": "PRESERVE ROWS", + } + + +@set_default_schema_on_connection.for_db("oracle") +def _oracle_set_default_schema_on_connection( + cfg, dbapi_connection, schema_name +): + cursor = dbapi_connection.cursor() + cursor.execute("ALTER SESSION SET CURRENT_SCHEMA=%s" % schema_name) + cursor.close() diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py new file mode 100644 index 0000000..12d9e94 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -0,0 +1,117 @@ +# postgresql/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from . import base +from . import pg8000 # noqa +from . import psycopg2 # noqa +from . import psycopg2cffi # noqa +from . import pygresql # noqa +from . import pypostgresql # noqa +from .array import All +from .array import Any +from .array import ARRAY +from .array import array +from .base import BIGINT +from .base import BIT +from .base import BOOLEAN +from .base import BYTEA +from .base import CHAR +from .base import CIDR +from .base import CreateEnumType +from .base import DATE +from .base import DOUBLE_PRECISION +from .base import DropEnumType +from .base import ENUM +from .base import FLOAT +from .base import INET +from .base import INTEGER +from .base import INTERVAL +from .base import MACADDR +from .base import MONEY +from .base import NUMERIC +from .base import OID +from .base import REAL +from .base import REGCLASS +from .base import SMALLINT +from .base import TEXT +from .base import TIME +from .base import TIMESTAMP +from .base import TSVECTOR +from .base import UUID +from .base import VARCHAR +from .dml import Insert +from .dml import insert +from .ext import aggregate_order_by +from .ext import array_agg +from .ext import ExcludeConstraint +from .hstore import HSTORE +from .hstore import hstore +from .json import JSON +from .json import JSONB +from .ranges import DATERANGE +from .ranges import INT4RANGE +from .ranges import INT8RANGE +from .ranges import NUMRANGE +from .ranges import TSRANGE +from .ranges import TSTZRANGE +from ...util import compat + +if compat.py3k: + from . import asyncpg # noqa + +base.dialect = dialect = psycopg2.dialect + + +__all__ = ( + "INTEGER", + "BIGINT", + "SMALLINT", + "VARCHAR", + "CHAR", + "TEXT", + "NUMERIC", + "FLOAT", + "REAL", + "INET", + "CIDR", + "UUID", + "BIT", + "MACADDR", + "MONEY", + "OID", + "REGCLASS", + "DOUBLE_PRECISION", + "TIMESTAMP", + "TIME", + "DATE", + "BYTEA", + "BOOLEAN", + "INTERVAL", + "ARRAY", + "ENUM", + "dialect", + "array", + "HSTORE", + "hstore", + "INT4RANGE", + "INT8RANGE", + "NUMRANGE", + "DATERANGE", + "TSVECTOR", + "TSRANGE", + "TSTZRANGE", + "JSON", + "JSONB", + "Any", + "All", + "DropEnumType", + "CreateEnumType", + "ExcludeConstraint", + "aggregate_order_by", + "array_agg", + "insert", + "Insert", +) diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py new file mode 100644 index 0000000..daf7c5d --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -0,0 +1,413 @@ +# postgresql/array.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import re + +from ... import types as sqltypes +from ... import util +from ...sql import coercions +from ...sql import expression +from ...sql import operators +from ...sql import roles + + +def Any(other, arrexpr, operator=operators.eq): + """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method. + See that method for details. + + """ + + return arrexpr.any(other, operator) + + +def All(other, arrexpr, operator=operators.eq): + """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method. + See that method for details. + + """ + + return arrexpr.all(other, operator) + + +class array(expression.ClauseList, expression.ColumnElement): + + """A PostgreSQL ARRAY literal. + + This is used to produce ARRAY literals in SQL expressions, e.g.:: + + from sqlalchemy.dialects.postgresql import array + from sqlalchemy.dialects import postgresql + from sqlalchemy import select, func + + stmt = select(array([1,2]) + array([3,4,5])) + + print(stmt.compile(dialect=postgresql.dialect())) + + Produces the SQL:: + + SELECT ARRAY[%(param_1)s, %(param_2)s] || + ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1 + + An instance of :class:`.array` will always have the datatype + :class:`_types.ARRAY`. The "inner" type of the array is inferred from + the values present, unless the ``type_`` keyword argument is passed:: + + array(['foo', 'bar'], type_=CHAR) + + Multidimensional arrays are produced by nesting :class:`.array` constructs. + The dimensionality of the final :class:`_types.ARRAY` + type is calculated by + recursively adding the dimensions of the inner :class:`_types.ARRAY` + type:: + + stmt = select( + array([ + array([1, 2]), array([3, 4]), array([column('q'), column('x')]) + ]) + ) + print(stmt.compile(dialect=postgresql.dialect())) + + Produces:: + + SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s], + ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1 + + .. versionadded:: 1.3.6 added support for multidimensional array literals + + .. seealso:: + + :class:`_postgresql.ARRAY` + + """ + + __visit_name__ = "array" + + stringify_dialect = "postgresql" + inherit_cache = True + + def __init__(self, clauses, **kw): + clauses = [ + coercions.expect(roles.ExpressionElementRole, c) for c in clauses + ] + + super(array, self).__init__(*clauses, **kw) + + self._type_tuple = [arg.type for arg in clauses] + main_type = kw.pop( + "type_", + self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE, + ) + + if isinstance(main_type, ARRAY): + self.type = ARRAY( + main_type.item_type, + dimensions=main_type.dimensions + 1 + if main_type.dimensions is not None + else 2, + ) + else: + self.type = ARRAY(main_type) + + @property + def _select_iterable(self): + return (self,) + + def _bind_param(self, operator, obj, _assume_scalar=False, type_=None): + if _assume_scalar or operator is operators.getitem: + return expression.BindParameter( + None, + obj, + _compared_to_operator=operator, + type_=type_, + _compared_to_type=self.type, + unique=True, + ) + + else: + return array( + [ + self._bind_param( + operator, o, _assume_scalar=True, type_=type_ + ) + for o in obj + ] + ) + + def self_group(self, against=None): + if against in (operators.any_op, operators.all_op, operators.getitem): + return expression.Grouping(self) + else: + return self + + +CONTAINS = operators.custom_op("@>", precedence=5, is_comparison=True) + +CONTAINED_BY = operators.custom_op("<@", precedence=5, is_comparison=True) + +OVERLAP = operators.custom_op("&&", precedence=5, is_comparison=True) + + +class ARRAY(sqltypes.ARRAY): + + """PostgreSQL ARRAY type. + + .. versionchanged:: 1.1 The :class:`_postgresql.ARRAY` type is now + a subclass of the core :class:`_types.ARRAY` type. + + The :class:`_postgresql.ARRAY` type is constructed in the same way + as the core :class:`_types.ARRAY` type; a member type is required, and a + number of dimensions is recommended if the type is to be used for more + than one dimension:: + + from sqlalchemy.dialects import postgresql + + mytable = Table("mytable", metadata, + Column("data", postgresql.ARRAY(Integer, dimensions=2)) + ) + + The :class:`_postgresql.ARRAY` type provides all operations defined on the + core :class:`_types.ARRAY` type, including support for "dimensions", + indexed access, and simple matching such as + :meth:`.types.ARRAY.Comparator.any` and + :meth:`.types.ARRAY.Comparator.all`. :class:`_postgresql.ARRAY` + class also + provides PostgreSQL-specific methods for containment operations, including + :meth:`.postgresql.ARRAY.Comparator.contains` + :meth:`.postgresql.ARRAY.Comparator.contained_by`, and + :meth:`.postgresql.ARRAY.Comparator.overlap`, e.g.:: + + mytable.c.data.contains([1, 2]) + + The :class:`_postgresql.ARRAY` type may not be supported on all + PostgreSQL DBAPIs; it is currently known to work on psycopg2 only. + + Additionally, the :class:`_postgresql.ARRAY` + type does not work directly in + conjunction with the :class:`.ENUM` type. For a workaround, see the + special type at :ref:`postgresql_array_of_enum`. + + .. seealso:: + + :class:`_types.ARRAY` - base array type + + :class:`_postgresql.array` - produces a literal array value. + + """ + + class Comparator(sqltypes.ARRAY.Comparator): + + """Define comparison operations for :class:`_types.ARRAY`. + + Note that these operations are in addition to those provided + by the base :class:`.types.ARRAY.Comparator` class, including + :meth:`.types.ARRAY.Comparator.any` and + :meth:`.types.ARRAY.Comparator.all`. + + """ + + def contains(self, other, **kwargs): + """Boolean expression. Test if elements are a superset of the + elements of the argument array expression. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other): + """Boolean expression. Test if elements are a proper subset of the + elements of the argument array expression. + """ + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) + + def overlap(self, other): + """Boolean expression. Test if array has elements in common with + an argument array expression. + """ + return self.operate(OVERLAP, other, result_type=sqltypes.Boolean) + + comparator_factory = Comparator + + def __init__( + self, item_type, as_tuple=False, dimensions=None, zero_indexes=False + ): + """Construct an ARRAY. + + E.g.:: + + Column('myarray', ARRAY(Integer)) + + Arguments are: + + :param item_type: The data type of items of this array. Note that + dimensionality is irrelevant here, so multi-dimensional arrays like + ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as + ``ARRAY(ARRAY(Integer))`` or such. + + :param as_tuple=False: Specify whether return results + should be converted to tuples from lists. DBAPIs such + as psycopg2 return lists by default. When tuples are + returned, the results are hashable. + + :param dimensions: if non-None, the ARRAY will assume a fixed + number of dimensions. This will cause the DDL emitted for this + ARRAY to include the exact number of bracket clauses ``[]``, + and will also optimize the performance of the type overall. + Note that PG arrays are always implicitly "non-dimensioned", + meaning they can store any number of dimensions no matter how + they were declared. + + :param zero_indexes=False: when True, index values will be converted + between Python zero-based and PostgreSQL one-based indexes, e.g. + a value of one will be added to all index values before passing + to the database. + + .. versionadded:: 0.9.5 + + + """ + if isinstance(item_type, ARRAY): + raise ValueError( + "Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype" + ) + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + self.as_tuple = as_tuple + self.dimensions = dimensions + self.zero_indexes = zero_indexes + + @property + def hashable(self): + return self.as_tuple + + @property + def python_type(self): + return list + + def compare_values(self, x, y): + return x == y + + def _proc_array(self, arr, itemproc, dim, collection): + if dim is None: + arr = list(arr) + if ( + dim == 1 + or dim is None + and ( + # this has to be (list, tuple), or at least + # not hasattr('__iter__'), since Py3K strings + # etc. have __iter__ + not arr + or not isinstance(arr[0], (list, tuple)) + ) + ): + if itemproc: + return collection(itemproc(x) for x in arr) + else: + return collection(arr) + else: + return collection( + self._proc_array( + x, + itemproc, + dim - 1 if dim is not None else None, + collection, + ) + for x in arr + ) + + @util.memoized_property + def _against_native_enum(self): + return ( + isinstance(self.item_type, sqltypes.Enum) + and self.item_type.native_enum + ) + + def bind_expression(self, bindvalue): + return bindvalue + + def bind_processor(self, dialect): + item_proc = self.item_type.dialect_impl(dialect).bind_processor( + dialect + ) + + def process(value): + if value is None: + return value + else: + return self._proc_array( + value, item_proc, self.dimensions, list + ) + + return process + + def result_processor(self, dialect, coltype): + item_proc = self.item_type.dialect_impl(dialect).result_processor( + dialect, coltype + ) + + def process(value): + if value is None: + return value + else: + return self._proc_array( + value, + item_proc, + self.dimensions, + tuple if self.as_tuple else list, + ) + + if self._against_native_enum: + super_rp = process + pattern = re.compile(r"^{(.*)}$") + + def handle_raw_string(value): + inner = pattern.match(value).group(1) + return _split_enum_values(inner) + + def process(value): + if value is None: + return value + # isinstance(value, util.string_types) is required to handle + # the case where a TypeDecorator for and Array of Enum is + # used like was required in sa < 1.3.17 + return super_rp( + handle_raw_string(value) + if isinstance(value, util.string_types) + else value + ) + + return process + + +def _split_enum_values(array_string): + + if '"' not in array_string: + # no escape char is present so it can just split on the comma + return array_string.split(",") if array_string else [] + + # handles quoted strings from: + # r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr' + # returns + # ['abc', 'quoted', 'also\\quoted', 'quoted, comma', 'esc " quot', 'qpr'] + text = array_string.replace(r"\"", "_$ESC_QUOTE$_") + text = text.replace(r"\\", "\\") + result = [] + on_quotes = re.split(r'(")', text) + in_quotes = False + for tok in on_quotes: + if tok == '"': + in_quotes = not in_quotes + elif in_quotes: + result.append(tok.replace("_$ESC_QUOTE$_", '"')) + else: + result.extend(re.findall(r"([^\s,]+),?", tok)) + return result diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py new file mode 100644 index 0000000..305ad46 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -0,0 +1,1112 @@ +# postgresql/asyncpg.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS +# file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: postgresql+asyncpg + :name: asyncpg + :dbapi: asyncpg + :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...] + :url: https://magicstack.github.io/asyncpg/ + +The asyncpg dialect is SQLAlchemy's first Python asyncio dialect. + +Using a special asyncio mediation layer, the asyncpg dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname") + +The dialect can also be run as a "synchronous" dialect within the +:func:`_sa.create_engine` function, which will pass "await" calls into +an ad-hoc event loop. This mode of operation is of **limited use** +and is for special testing scenarios only. The mode can be enabled by +adding the SQLAlchemy-specific flag ``async_fallback`` to the URL +in conjunction with :func:`_sa.create_engine`:: + + # for testing purposes only; do not use in production! + engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true") + + +.. versionadded:: 1.4 + +.. note:: + + By default asyncpg does not decode the ``json`` and ``jsonb`` types and + returns them as strings. SQLAlchemy sets default type decoder for ``json`` + and ``jsonb`` types using the python builtin ``json.loads`` function. + The json implementation used can be changed by setting the attribute + ``json_deserializer`` when creating the engine with + :func:`create_engine` or :func:`create_async_engine`. + + +.. _asyncpg_prepared_statement_cache: + +Prepared Statement Cache +-------------------------- + +The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()`` +for all statements. The prepared statement objects are cached after +construction which appears to grant a 10% or more performance improvement for +statement invocation. The cache is on a per-DBAPI connection basis, which +means that the primary storage for prepared statements is within DBAPI +connections pooled within the connection pool. The size of this cache +defaults to 100 statements per DBAPI connection and may be adjusted using the +``prepared_statement_cache_size`` DBAPI argument (note that while this argument +is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the +asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect +argument):: + + + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500") + +To disable the prepared statement cache, use a value of zero:: + + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0") + +.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg. + + +.. warning:: The ``asyncpg`` database driver necessarily uses caches for + PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes + such as ``ENUM`` objects are changed via DDL operations. Additionally, + prepared statements themselves which are optionally cached by SQLAlchemy's + driver as described above may also become "stale" when DDL has been emitted + to the PostgreSQL database which modifies the tables or other objects + involved in a particular prepared statement. + + The SQLAlchemy asyncpg dialect will invalidate these caches within its local + process when statements that represent DDL are emitted on a local + connection, but this is only controllable within a single Python process / + database engine. If DDL changes are made from other database engines + and/or processes, a running application may encounter asyncpg exceptions + ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup + failed for type <oid>")`` if it refers to pooled database connections which + operated upon the previous structures. The SQLAlchemy asyncpg dialect will + recover from these error cases when the driver raises these exceptions by + clearing its internal caches as well as those of the asyncpg driver in + response to them, but cannot prevent them from being raised in the first + place if the cached prepared statement or asyncpg type caches have gone + stale, nor can it retry the statement as the PostgreSQL transaction is + invalidated when these errors occur. + +Disabling the PostgreSQL JIT to improve ENUM datatype handling +--------------------------------------------------------------- + +Asyncpg has an `issue <https://github.com/MagicStack/asyncpg/issues/727>`_ when +using PostgreSQL ENUM datatypes, where upon the creation of new database +connections, an expensive query may be emitted in order to retrieve metadata +regarding custom types which has been shown to negatively affect performance. +To mitigate this issue, the PostgreSQL "jit" setting may be disabled from the +client using this setting passed to :func:`_asyncio.create_async_engine`:: + + engine = create_async_engine( + "postgresql+asyncpg://user:password@localhost/tmp", + connect_args={"server_settings": {"jit": "off"}}, + ) + +.. seealso:: + + https://github.com/MagicStack/asyncpg/issues/727 + +""" # noqa + +import collections +import decimal +import json as _py_json +import re +import time + +from . import json +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import ENUM +from .base import INTERVAL +from .base import OID +from .base import PGCompiler +from .base import PGDialect +from .base import PGExecutionContext +from .base import PGIdentifierPreparer +from .base import REGCLASS +from .base import UUID +from ... import exc +from ... import pool +from ... import processors +from ... import util +from ...engine import AdaptedConnection +from ...sql import sqltypes +from ...util.concurrency import asyncio +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +try: + from uuid import UUID as _python_UUID # noqa +except ImportError: + _python_UUID = None + + +class AsyncpgTime(sqltypes.Time): + def get_dbapi_type(self, dbapi): + if self.timezone: + return dbapi.TIME_W_TZ + else: + return dbapi.TIME + + +class AsyncpgDate(sqltypes.Date): + def get_dbapi_type(self, dbapi): + return dbapi.DATE + + +class AsyncpgDateTime(sqltypes.DateTime): + def get_dbapi_type(self, dbapi): + if self.timezone: + return dbapi.TIMESTAMP_W_TZ + else: + return dbapi.TIMESTAMP + + +class AsyncpgBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.BOOLEAN + + +class AsyncPgInterval(INTERVAL): + def get_dbapi_type(self, dbapi): + return dbapi.INTERVAL + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + + return AsyncPgInterval(precision=interval.second_precision) + + +class AsyncPgEnum(ENUM): + def get_dbapi_type(self, dbapi): + return dbapi.ENUM + + +class AsyncpgInteger(sqltypes.Integer): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class AsyncpgBigInteger(sqltypes.BigInteger): + def get_dbapi_type(self, dbapi): + return dbapi.BIGINTEGER + + +class AsyncpgJSON(json.JSON): + def get_dbapi_type(self, dbapi): + return dbapi.JSON + + def result_processor(self, dialect, coltype): + return None + + +class AsyncpgJSONB(json.JSONB): + def get_dbapi_type(self, dbapi): + return dbapi.JSONB + + def result_processor(self, dialect, coltype): + return None + + +class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType): + def get_dbapi_type(self, dbapi): + raise NotImplementedError("should not be here") + + +class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + def get_dbapi_type(self, dbapi): + return dbapi.STRING + + +class AsyncpgJSONPathType(json.JSONPathType): + def bind_processor(self, dialect): + def process(value): + assert isinstance(value, util.collections_abc.Sequence) + tokens = [util.text_type(elem) for elem in value] + return tokens + + return process + + +class AsyncpgUUID(UUID): + def get_dbapi_type(self, dbapi): + return dbapi.UUID + + def bind_processor(self, dialect): + if not self.as_uuid and dialect.use_native_uuid: + + def process(value): + if value is not None: + value = _python_UUID(value) + return value + + return process + + def result_processor(self, dialect, coltype): + if not self.as_uuid and dialect.use_native_uuid: + + def process(value): + if value is not None: + value = str(value) + return value + + return process + + +class AsyncpgNumeric(sqltypes.Numeric): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # pg8000 returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class AsyncpgFloat(AsyncpgNumeric): + def get_dbapi_type(self, dbapi): + return dbapi.FLOAT + + +class AsyncpgREGCLASS(REGCLASS): + def get_dbapi_type(self, dbapi): + return dbapi.STRING + + +class AsyncpgOID(OID): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class PGExecutionContext_asyncpg(PGExecutionContext): + def handle_dbapi_exception(self, e): + if isinstance( + e, + ( + self.dialect.dbapi.InvalidCachedStatementError, + self.dialect.dbapi.InternalServerError, + ), + ): + self.dialect._invalidate_schema_cache() + + def pre_exec(self): + if self.isddl: + self.dialect._invalidate_schema_cache() + + self.cursor._invalidate_schema_cache_asof = ( + self.dialect._invalidate_schema_cache_asof + ) + + if not self.compiled: + return + + # we have to exclude ENUM because "enum" not really a "type" + # we can cast to, it has to be the name of the type itself. + # for now we just omit it from casting + self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM} + + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(server_side=True) + + +class PGCompiler_asyncpg(PGCompiler): + pass + + +class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer): + pass + + +class AsyncAdapt_asyncpg_cursor: + __slots__ = ( + "_adapt_connection", + "_connection", + "_rows", + "description", + "arraysize", + "rowcount", + "_inputsizes", + "_cursor", + "_invalidate_schema_cache_asof", + ) + + server_side = False + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self._rows = [] + self._cursor = None + self.description = None + self.arraysize = 1 + self.rowcount = -1 + self._inputsizes = None + self._invalidate_schema_cache_asof = 0 + + def close(self): + self._rows[:] = [] + + def _handle_exception(self, error): + self._adapt_connection._handle_exception(error) + + def _parameter_placeholders(self, params): + if not self._inputsizes: + return tuple("$%d" % idx for idx, _ in enumerate(params, 1)) + else: + return tuple( + "$%d::%s" % (idx, typ) if typ else "$%d" % idx + for idx, typ in enumerate( + (_pg_types.get(typ) for typ in self._inputsizes), 1 + ) + ) + + async def _prepare_and_execute(self, operation, parameters): + adapt_connection = self._adapt_connection + + async with adapt_connection._execute_mutex: + + if not adapt_connection._started: + await adapt_connection._start_transaction() + + if parameters is not None: + operation = operation % self._parameter_placeholders( + parameters + ) + else: + parameters = () + + try: + prepared_stmt, attributes = await adapt_connection._prepare( + operation, self._invalidate_schema_cache_asof + ) + + if attributes: + self.description = [ + ( + attr.name, + attr.type.oid, + None, + None, + None, + None, + None, + ) + for attr in attributes + ] + else: + self.description = None + + if self.server_side: + self._cursor = await prepared_stmt.cursor(*parameters) + self.rowcount = -1 + else: + self._rows = await prepared_stmt.fetch(*parameters) + status = prepared_stmt.get_statusmsg() + + reg = re.match( + r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status + ) + if reg: + self.rowcount = int(reg.group(1)) + else: + self.rowcount = -1 + + except Exception as error: + self._handle_exception(error) + + async def _executemany(self, operation, seq_of_parameters): + adapt_connection = self._adapt_connection + + async with adapt_connection._execute_mutex: + await adapt_connection._check_type_cache_invalidation( + self._invalidate_schema_cache_asof + ) + + if not adapt_connection._started: + await adapt_connection._start_transaction() + + operation = operation % self._parameter_placeholders( + seq_of_parameters[0] + ) + + try: + return await self._connection.executemany( + operation, seq_of_parameters + ) + except Exception as error: + self._handle_exception(error) + + def execute(self, operation, parameters=None): + self._adapt_connection.await_( + self._prepare_and_execute(operation, parameters) + ) + + def executemany(self, operation, seq_of_parameters): + return self._adapt_connection.await_( + self._executemany(operation, seq_of_parameters) + ) + + def setinputsizes(self, *inputsizes): + self._inputsizes = inputsizes + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): + + server_side = True + __slots__ = ("_rowbuffer",) + + def __init__(self, adapt_connection): + super(AsyncAdapt_asyncpg_ss_cursor, self).__init__(adapt_connection) + self._rowbuffer = None + + def close(self): + self._cursor = None + self._rowbuffer = None + + def _buffer_rows(self): + new_rows = self._adapt_connection.await_(self._cursor.fetch(50)) + self._rowbuffer = collections.deque(new_rows) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._rowbuffer: + self._buffer_rows() + + while True: + while self._rowbuffer: + yield self._rowbuffer.popleft() + + self._buffer_rows() + if not self._rowbuffer: + break + + def fetchone(self): + if not self._rowbuffer: + self._buffer_rows() + if not self._rowbuffer: + return None + return self._rowbuffer.popleft() + + def fetchmany(self, size=None): + if size is None: + return self.fetchall() + + if not self._rowbuffer: + self._buffer_rows() + + buf = list(self._rowbuffer) + lb = len(buf) + if size > lb: + buf.extend( + self._adapt_connection.await_(self._cursor.fetch(size - lb)) + ) + + result = buf[0:size] + self._rowbuffer = collections.deque(buf[size:]) + return result + + def fetchall(self): + ret = list(self._rowbuffer) + list( + self._adapt_connection.await_(self._all()) + ) + self._rowbuffer.clear() + return ret + + async def _all(self): + rows = [] + + # TODO: looks like we have to hand-roll some kind of batching here. + # hardcoding for the moment but this should be improved. + while True: + batch = await self._cursor.fetch(1000) + if batch: + rows.extend(batch) + continue + else: + break + return rows + + def executemany(self, operation, seq_of_parameters): + raise NotImplementedError( + "server side cursor doesn't support executemany yet" + ) + + +class AsyncAdapt_asyncpg_connection(AdaptedConnection): + __slots__ = ( + "dbapi", + "_connection", + "isolation_level", + "_isolation_setting", + "readonly", + "deferrable", + "_transaction", + "_started", + "_prepared_statement_cache", + "_invalidate_schema_cache_asof", + "_execute_mutex", + ) + + await_ = staticmethod(await_only) + + def __init__(self, dbapi, connection, prepared_statement_cache_size=100): + self.dbapi = dbapi + self._connection = connection + self.isolation_level = self._isolation_setting = "read_committed" + self.readonly = False + self.deferrable = False + self._transaction = None + self._started = False + self._invalidate_schema_cache_asof = time.time() + self._execute_mutex = asyncio.Lock() + + if prepared_statement_cache_size: + self._prepared_statement_cache = util.LRUCache( + prepared_statement_cache_size + ) + else: + self._prepared_statement_cache = None + + async def _check_type_cache_invalidation(self, invalidate_timestamp): + if invalidate_timestamp > self._invalidate_schema_cache_asof: + await self._connection.reload_schema_state() + self._invalidate_schema_cache_asof = invalidate_timestamp + + async def _prepare(self, operation, invalidate_timestamp): + await self._check_type_cache_invalidation(invalidate_timestamp) + + cache = self._prepared_statement_cache + if cache is None: + prepared_stmt = await self._connection.prepare(operation) + attributes = prepared_stmt.get_attributes() + return prepared_stmt, attributes + + # asyncpg uses a type cache for the "attributes" which seems to go + # stale independently of the PreparedStatement itself, so place that + # collection in the cache as well. + if operation in cache: + prepared_stmt, attributes, cached_timestamp = cache[operation] + + # preparedstatements themselves also go stale for certain DDL + # changes such as size of a VARCHAR changing, so there is also + # a cross-connection invalidation timestamp + if cached_timestamp > invalidate_timestamp: + return prepared_stmt, attributes + + prepared_stmt = await self._connection.prepare(operation) + attributes = prepared_stmt.get_attributes() + cache[operation] = (prepared_stmt, attributes, time.time()) + + return prepared_stmt, attributes + + def _handle_exception(self, error): + if self._connection.is_closed(): + self._transaction = None + self._started = False + + if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error): + exception_mapping = self.dbapi._asyncpg_error_translate + + for super_ in type(error).__mro__: + if super_ in exception_mapping: + translated_error = exception_mapping[super_]( + "%s: %s" % (type(error), error) + ) + translated_error.pgcode = ( + translated_error.sqlstate + ) = getattr(error, "sqlstate", None) + raise translated_error from error + else: + raise error + else: + raise error + + @property + def autocommit(self): + return self.isolation_level == "autocommit" + + @autocommit.setter + def autocommit(self, value): + if value: + self.isolation_level = "autocommit" + else: + self.isolation_level = self._isolation_setting + + def set_isolation_level(self, level): + if self._started: + self.rollback() + self.isolation_level = self._isolation_setting = level + + async def _start_transaction(self): + if self.isolation_level == "autocommit": + return + + try: + self._transaction = self._connection.transaction( + isolation=self.isolation_level, + readonly=self.readonly, + deferrable=self.deferrable, + ) + await self._transaction.start() + except Exception as error: + self._handle_exception(error) + else: + self._started = True + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_asyncpg_ss_cursor(self) + else: + return AsyncAdapt_asyncpg_cursor(self) + + def rollback(self): + if self._started: + try: + self.await_(self._transaction.rollback()) + except Exception as error: + self._handle_exception(error) + finally: + self._transaction = None + self._started = False + + def commit(self): + if self._started: + try: + self.await_(self._transaction.commit()) + except Exception as error: + self._handle_exception(error) + finally: + self._transaction = None + self._started = False + + def close(self): + self.rollback() + + self.await_(self._connection.close()) + + +class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_asyncpg_dbapi: + def __init__(self, asyncpg): + self.asyncpg = asyncpg + self.paramstyle = "format" + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + prepared_statement_cache_size = kw.pop( + "prepared_statement_cache_size", 100 + ) + if util.asbool(async_fallback): + return AsyncAdaptFallback_asyncpg_connection( + self, + await_fallback(self.asyncpg.connect(*arg, **kw)), + prepared_statement_cache_size=prepared_statement_cache_size, + ) + else: + return AsyncAdapt_asyncpg_connection( + self, + await_only(self.asyncpg.connect(*arg, **kw)), + prepared_statement_cache_size=prepared_statement_cache_size, + ) + + class Error(Exception): + pass + + class Warning(Exception): # noqa + pass + + class InterfaceError(Error): + pass + + class DatabaseError(Error): + pass + + class InternalError(DatabaseError): + pass + + class OperationalError(DatabaseError): + pass + + class ProgrammingError(DatabaseError): + pass + + class IntegrityError(DatabaseError): + pass + + class DataError(DatabaseError): + pass + + class NotSupportedError(DatabaseError): + pass + + class InternalServerError(InternalError): + pass + + class InvalidCachedStatementError(NotSupportedError): + def __init__(self, message): + super( + AsyncAdapt_asyncpg_dbapi.InvalidCachedStatementError, self + ).__init__( + message + " (SQLAlchemy asyncpg dialect will now invalidate " + "all prepared caches in response to this exception)", + ) + + @util.memoized_property + def _asyncpg_error_translate(self): + import asyncpg + + return { + asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError, # noqa: E501 + asyncpg.exceptions.PostgresError: self.Error, + asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError, + asyncpg.exceptions.InterfaceError: self.InterfaceError, + asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError, # noqa: E501 + asyncpg.exceptions.InternalServerError: self.InternalServerError, + } + + def Binary(self, value): + return value + + STRING = util.symbol("STRING") + TIMESTAMP = util.symbol("TIMESTAMP") + TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ") + TIME = util.symbol("TIME") + TIME_W_TZ = util.symbol("TIME_W_TZ") + DATE = util.symbol("DATE") + INTERVAL = util.symbol("INTERVAL") + NUMBER = util.symbol("NUMBER") + FLOAT = util.symbol("FLOAT") + BOOLEAN = util.symbol("BOOLEAN") + INTEGER = util.symbol("INTEGER") + BIGINTEGER = util.symbol("BIGINTEGER") + BYTES = util.symbol("BYTES") + DECIMAL = util.symbol("DECIMAL") + JSON = util.symbol("JSON") + JSONB = util.symbol("JSONB") + ENUM = util.symbol("ENUM") + UUID = util.symbol("UUID") + BYTEA = util.symbol("BYTEA") + + DATETIME = TIMESTAMP + BINARY = BYTEA + + +_pg_types = { + AsyncAdapt_asyncpg_dbapi.STRING: "varchar", + AsyncAdapt_asyncpg_dbapi.TIMESTAMP: "timestamp", + AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone", + AsyncAdapt_asyncpg_dbapi.DATE: "date", + AsyncAdapt_asyncpg_dbapi.TIME: "time", + AsyncAdapt_asyncpg_dbapi.TIME_W_TZ: "time with time zone", + AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval", + AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric", + AsyncAdapt_asyncpg_dbapi.FLOAT: "float", + AsyncAdapt_asyncpg_dbapi.BOOLEAN: "bool", + AsyncAdapt_asyncpg_dbapi.INTEGER: "integer", + AsyncAdapt_asyncpg_dbapi.BIGINTEGER: "bigint", + AsyncAdapt_asyncpg_dbapi.BYTES: "bytes", + AsyncAdapt_asyncpg_dbapi.DECIMAL: "decimal", + AsyncAdapt_asyncpg_dbapi.JSON: "json", + AsyncAdapt_asyncpg_dbapi.JSONB: "jsonb", + AsyncAdapt_asyncpg_dbapi.ENUM: "enum", + AsyncAdapt_asyncpg_dbapi.UUID: "uuid", + AsyncAdapt_asyncpg_dbapi.BYTEA: "bytea", +} + + +class PGDialect_asyncpg(PGDialect): + driver = "asyncpg" + supports_statement_cache = True + + supports_unicode_statements = True + supports_server_side_cursors = True + + supports_unicode_binds = True + + default_paramstyle = "format" + supports_sane_multi_rowcount = False + execution_ctx_cls = PGExecutionContext_asyncpg + statement_compiler = PGCompiler_asyncpg + preparer = PGIdentifierPreparer_asyncpg + + use_setinputsizes = True + + use_native_uuid = True + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Time: AsyncpgTime, + sqltypes.Date: AsyncpgDate, + sqltypes.DateTime: AsyncpgDateTime, + sqltypes.Interval: AsyncPgInterval, + INTERVAL: AsyncPgInterval, + UUID: AsyncpgUUID, + sqltypes.Boolean: AsyncpgBoolean, + sqltypes.Integer: AsyncpgInteger, + sqltypes.BigInteger: AsyncpgBigInteger, + sqltypes.Numeric: AsyncpgNumeric, + sqltypes.Float: AsyncpgFloat, + sqltypes.JSON: AsyncpgJSON, + json.JSONB: AsyncpgJSONB, + sqltypes.JSON.JSONPathType: AsyncpgJSONPathType, + sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType, + sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType, + sqltypes.Enum: AsyncPgEnum, + OID: AsyncpgOID, + REGCLASS: AsyncpgREGCLASS, + }, + ) + is_async = True + _invalidate_schema_cache_asof = 0 + + def _invalidate_schema_cache(self): + self._invalidate_schema_cache_asof = time.time() + + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + return tuple( + [ + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) + else: + return (99, 99, 99) + + @classmethod + def dbapi(cls): + return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg")) + + @util.memoized_property + def _isolation_lookup(self): + return { + "AUTOCOMMIT": "autocommit", + "READ COMMITTED": "read_committed", + "REPEATABLE READ": "repeatable_read", + "SERIALIZABLE": "serializable", + } + + def set_isolation_level(self, connection, level): + try: + level = self._isolation_lookup[level.replace("_", " ")] + except KeyError as err: + util.raise_( + exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) + ), + replace_context=err, + ) + + connection.set_isolation_level(level) + + def set_readonly(self, connection, value): + connection.readonly = value + + def get_readonly(self, connection): + return connection.readonly + + def set_deferrable(self, connection, value): + connection.deferrable = value + + def get_deferrable(self, connection): + return connection.deferrable + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + + opts.update(url.query) + util.coerce_kw_type(opts, "prepared_statement_cache_size", int) + util.coerce_kw_type(opts, "port", int) + return ([], opts) + + @classmethod + def get_pool_class(cls, url): + + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def is_disconnect(self, e, connection, cursor): + if connection: + return connection._connection.is_closed() + else: + return isinstance( + e, self.dbapi.InterfaceError + ) and "connection is closed" in str(e) + + def do_set_input_sizes(self, cursor, list_of_tuples, context): + if self.positional: + cursor.setinputsizes( + *[dbtype for key, dbtype, sqltype in list_of_tuples] + ) + else: + cursor.setinputsizes( + **{ + key: dbtype + for key, dbtype, sqltype in list_of_tuples + if dbtype + } + ) + + async def setup_asyncpg_json_codec(self, conn): + """set up JSON codec for asyncpg. + + This occurs for all new connections and + can be overridden by third party dialects. + + .. versionadded:: 1.4.27 + + """ + + asyncpg_connection = conn._connection + deserializer = self._json_deserializer or _py_json.loads + + def _json_decoder(bin_value): + return deserializer(bin_value.decode()) + + await asyncpg_connection.set_type_codec( + "json", + encoder=str.encode, + decoder=_json_decoder, + schema="pg_catalog", + format="binary", + ) + + async def setup_asyncpg_jsonb_codec(self, conn): + """set up JSONB codec for asyncpg. + + This occurs for all new connections and + can be overridden by third party dialects. + + .. versionadded:: 1.4.27 + + """ + + asyncpg_connection = conn._connection + deserializer = self._json_deserializer or _py_json.loads + + def _jsonb_encoder(str_value): + # \x01 is the prefix for jsonb used by PostgreSQL. + # asyncpg requires it when format='binary' + return b"\x01" + str_value.encode() + + deserializer = self._json_deserializer or _py_json.loads + + def _jsonb_decoder(bin_value): + # the byte is the \x01 prefix for jsonb used by PostgreSQL. + # asyncpg returns it when format='binary' + return deserializer(bin_value[1:].decode()) + + await asyncpg_connection.set_type_codec( + "jsonb", + encoder=_jsonb_encoder, + decoder=_jsonb_decoder, + schema="pg_catalog", + format="binary", + ) + + def on_connect(self): + """on_connect for asyncpg + + A major component of this for asyncpg is to set up type decoders at the + asyncpg level. + + See https://github.com/MagicStack/asyncpg/issues/623 for + notes on JSON/JSONB implementation. + + """ + + super_connect = super(PGDialect_asyncpg, self).on_connect() + + def connect(conn): + conn.await_(self.setup_asyncpg_json_codec(conn)) + conn.await_(self.setup_asyncpg_jsonb_codec(conn)) + if super_connect is not None: + super_connect(conn) + + return connect + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = PGDialect_asyncpg diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py new file mode 100644 index 0000000..eb84170 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -0,0 +1,4651 @@ +# postgresql/base.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" +.. dialect:: postgresql + :name: PostgreSQL + :full_support: 9.6, 10, 11, 12, 13, 14 + :normal_support: 9.6+ + :best_effort: 8+ + +.. _postgresql_sequences: + +Sequences/SERIAL/IDENTITY +------------------------- + +PostgreSQL supports sequences, and SQLAlchemy uses these as the default means +of creating new primary key values for integer-based primary key columns. When +creating tables, SQLAlchemy will issue the ``SERIAL`` datatype for +integer-based primary key columns, which generates a sequence and server side +default corresponding to the column. + +To specify a specific named sequence to be used for primary key generation, +use the :func:`~sqlalchemy.schema.Sequence` construct:: + + Table('sometable', metadata, + Column('id', Integer, Sequence('some_id_seq'), primary_key=True) + ) + +When SQLAlchemy issues a single INSERT statement, to fulfill the contract of +having the "last insert identifier" available, a RETURNING clause is added to +the INSERT statement which specifies the primary key columns should be +returned after the statement completes. The RETURNING functionality only takes +place if PostgreSQL 8.2 or later is in use. As a fallback approach, the +sequence, whether specified explicitly or implicitly via ``SERIAL``, is +executed independently beforehand, the returned value to be used in the +subsequent insert. Note that when an +:func:`~sqlalchemy.sql.expression.insert()` construct is executed using +"executemany" semantics, the "last inserted identifier" functionality does not +apply; no RETURNING clause is emitted nor is the sequence pre-executed in this +case. + +To force the usage of RETURNING by default off, specify the flag +``implicit_returning=False`` to :func:`_sa.create_engine`. + +PostgreSQL 10 and above IDENTITY columns +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL 10 and above have a new IDENTITY feature that supersedes the use +of SERIAL. The :class:`_schema.Identity` construct in a +:class:`_schema.Column` can be used to control its behavior:: + + from sqlalchemy import Table, Column, MetaData, Integer, Computed + + metadata = MetaData() + + data = Table( + "data", + metadata, + Column( + 'id', Integer, Identity(start=42, cycle=True), primary_key=True + ), + Column('data', String) + ) + +The CREATE TABLE for the above :class:`_schema.Table` object would be: + +.. sourcecode:: sql + + CREATE TABLE data ( + id INTEGER GENERATED BY DEFAULT AS IDENTITY (START WITH 42 CYCLE), + data VARCHAR, + PRIMARY KEY (id) + ) + +.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct + in a :class:`_schema.Column` to specify the option of an autoincrementing + column. + +.. note:: + + Previous versions of SQLAlchemy did not have built-in support for rendering + of IDENTITY, and could use the following compilation hook to replace + occurrences of SERIAL with IDENTITY:: + + from sqlalchemy.schema import CreateColumn + from sqlalchemy.ext.compiler import compiles + + + @compiles(CreateColumn, 'postgresql') + def use_identity(element, compiler, **kw): + text = compiler.visit_create_column(element, **kw) + text = text.replace( + "SERIAL", "INT GENERATED BY DEFAULT AS IDENTITY" + ) + return text + + Using the above, a table such as:: + + t = Table( + 't', m, + Column('id', Integer, primary_key=True), + Column('data', String) + ) + + Will generate on the backing database as:: + + CREATE TABLE t ( + id INT GENERATED BY DEFAULT AS IDENTITY, + data VARCHAR, + PRIMARY KEY (id) + ) + +.. _postgresql_ss_cursors: + +Server Side Cursors +------------------- + +Server-side cursor support is available for the psycopg2, asyncpg +dialects and may also be available in others. + +Server side cursors are enabled on a per-statement basis by using the +:paramref:`.Connection.execution_options.stream_results` connection execution +option:: + + with engine.connect() as conn: + result = conn.execution_options(stream_results=True).execute(text("select * from table")) + +Note that some kinds of SQL statements may not be supported with +server side cursors; generally, only SQL statements that return rows should be +used with this option. + +.. deprecated:: 1.4 The dialect-level server_side_cursors flag is deprecated + and will be removed in a future release. Please use the + :paramref:`_engine.Connection.stream_results` execution option for + unbuffered cursor support. + +.. seealso:: + + :ref:`engine_stream_results` + +.. _postgresql_isolation_level: + +Transaction Isolation Level +--------------------------- + +Most SQLAlchemy dialects support setting of transaction isolation level +using the :paramref:`_sa.create_engine.isolation_level` parameter +at the :func:`_sa.create_engine` level, and at the :class:`_engine.Connection` +level via the :paramref:`.Connection.execution_options.isolation_level` +parameter. + +For PostgreSQL dialects, this feature works either by making use of the +DBAPI-specific features, such as psycopg2's isolation level flags which will +embed the isolation level setting inline with the ``"BEGIN"`` statement, or for +DBAPIs with no direct support by emitting ``SET SESSION CHARACTERISTICS AS +TRANSACTION ISOLATION LEVEL <level>`` ahead of the ``"BEGIN"`` statement +emitted by the DBAPI. For the special AUTOCOMMIT isolation level, +DBAPI-specific techniques are used which is typically an ``.autocommit`` +flag on the DBAPI connection object. + +To set isolation level using :func:`_sa.create_engine`:: + + engine = create_engine( + "postgresql+pg8000://scott:tiger@localhost/test", + isolation_level = "REPEATABLE READ" + ) + +To set using per-connection execution options:: + + with engine.connect() as conn: + conn = conn.execution_options( + isolation_level="REPEATABLE READ" + ) + with conn.begin(): + # ... work with transaction + +There are also more options for isolation level configurations, such as +"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply +different isolation level settings. See the discussion at +:ref:`dbapi_autocommit` for background. + +Valid values for ``isolation_level`` on most PostgreSQL dialects include: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +.. seealso:: + + :ref:`dbapi_autocommit` + + :ref:`postgresql_readonly_deferrable` + + :ref:`psycopg2_isolation_level` + + :ref:`pg8000_isolation_level` + +.. _postgresql_readonly_deferrable: + +Setting READ ONLY / DEFERRABLE +------------------------------ + +Most PostgreSQL dialects support setting the "READ ONLY" and "DEFERRABLE" +characteristics of the transaction, which is in addition to the isolation level +setting. These two attributes can be established either in conjunction with or +independently of the isolation level by passing the ``postgresql_readonly`` and +``postgresql_deferrable`` flags with +:meth:`_engine.Connection.execution_options`. The example below illustrates +passing the ``"SERIALIZABLE"`` isolation level at the same time as setting +"READ ONLY" and "DEFERRABLE":: + + with engine.connect() as conn: + conn = conn.execution_options( + isolation_level="SERIALIZABLE", + postgresql_readonly=True, + postgresql_deferrable=True + ) + with conn.begin(): + # ... work with transaction + +Note that some DBAPIs such as asyncpg only support "readonly" with +SERIALIZABLE isolation. + +.. versionadded:: 1.4 added support for the ``postgresql_readonly`` + and ``postgresql_deferrable`` execution options. + +.. _postgresql_alternate_search_path: + +Setting Alternate Search Paths on Connect +------------------------------------------ + +The PostgreSQL ``search_path`` variable refers to the list of schema names +that will be implicitly referred towards when a particular table or other +object is referenced in a SQL statement. As detailed in the next section +:ref:`postgresql_schema_reflection`, SQLAlchemy is generally organized around +the concept of keeping this variable at its default value of ``public``, +however, in order to have it set to any arbitrary name or names when connections +are used automatically, the "SET SESSION search_path" command may be invoked +for all connections in a pool using the following event handler, as discussed +at :ref:`schema_set_default_connections`:: + + from sqlalchemy import event + from sqlalchemy import create_engine + + engine = create_engine("postgresql+psycopg2://scott:tiger@host/dbname") + + @event.listens_for(engine, "connect", insert=True) + def set_search_path(dbapi_connection, connection_record): + existing_autocommit = dbapi_connection.autocommit + dbapi_connection.autocommit = True + cursor = dbapi_connection.cursor() + cursor.execute("SET SESSION search_path='%s'" % schema_name) + cursor.close() + dbapi_connection.autocommit = existing_autocommit + +The reason the recipe is complicated by use of the ``.autocommit`` DBAPI +attribute is so that when the ``SET SESSION search_path`` directive is invoked, +it is invoked outside of the scope of any transaction and therefore will not +be reverted when the DBAPI connection has a rollback. + +.. seealso:: + + :ref:`schema_set_default_connections` - in the :ref:`metadata_toplevel` documentation + + + + +.. _postgresql_schema_reflection: + +Remote-Schema Table Introspection and PostgreSQL search_path +------------------------------------------------------------ + +.. admonition:: Section Best Practices Summarized + + keep the ``search_path`` variable set to its default of ``public``, without + any other schema names. For other schema names, name these explicitly + within :class:`_schema.Table` definitions. Alternatively, the + ``postgresql_ignore_search_path`` option will cause all reflected + :class:`_schema.Table` objects to have a :attr:`_schema.Table.schema` + attribute set up. + +The PostgreSQL dialect can reflect tables from any schema, as outlined in +:ref:`metadata_reflection_schemas`. + +With regards to tables which these :class:`_schema.Table` +objects refer to via foreign key constraint, a decision must be made as to how +the ``.schema`` is represented in those remote tables, in the case where that +remote schema name is also a member of the current +`PostgreSQL search path +<https://www.postgresql.org/docs/current/static/ddl-schemas.html#DDL-SCHEMAS-PATH>`_. + +By default, the PostgreSQL dialect mimics the behavior encouraged by +PostgreSQL's own ``pg_get_constraintdef()`` builtin procedure. This function +returns a sample definition for a particular foreign key constraint, +omitting the referenced schema name from that definition when the name is +also in the PostgreSQL schema search path. The interaction below +illustrates this behavior:: + + test=> CREATE TABLE test_schema.referred(id INTEGER PRIMARY KEY); + CREATE TABLE + test=> CREATE TABLE referring( + test(> id INTEGER PRIMARY KEY, + test(> referred_id INTEGER REFERENCES test_schema.referred(id)); + CREATE TABLE + test=> SET search_path TO public, test_schema; + test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM + test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n + test-> ON n.oid = c.relnamespace + test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid + test-> WHERE c.relname='referring' AND r.contype = 'f' + test-> ; + pg_get_constraintdef + --------------------------------------------------- + FOREIGN KEY (referred_id) REFERENCES referred(id) + (1 row) + +Above, we created a table ``referred`` as a member of the remote schema +``test_schema``, however when we added ``test_schema`` to the +PG ``search_path`` and then asked ``pg_get_constraintdef()`` for the +``FOREIGN KEY`` syntax, ``test_schema`` was not included in the output of +the function. + +On the other hand, if we set the search path back to the typical default +of ``public``:: + + test=> SET search_path TO public; + SET + +The same query against ``pg_get_constraintdef()`` now returns the fully +schema-qualified name for us:: + + test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM + test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n + test-> ON n.oid = c.relnamespace + test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid + test-> WHERE c.relname='referring' AND r.contype = 'f'; + pg_get_constraintdef + --------------------------------------------------------------- + FOREIGN KEY (referred_id) REFERENCES test_schema.referred(id) + (1 row) + +SQLAlchemy will by default use the return value of ``pg_get_constraintdef()`` +in order to determine the remote schema name. That is, if our ``search_path`` +were set to include ``test_schema``, and we invoked a table +reflection process as follows:: + + >>> from sqlalchemy import Table, MetaData, create_engine, text + >>> engine = create_engine("postgresql://scott:tiger@localhost/test") + >>> with engine.connect() as conn: + ... conn.execute(text("SET search_path TO test_schema, public")) + ... metadata_obj = MetaData() + ... referring = Table('referring', metadata_obj, + ... autoload_with=conn) + ... + <sqlalchemy.engine.result.CursorResult object at 0x101612ed0> + +The above process would deliver to the :attr:`_schema.MetaData.tables` +collection +``referred`` table named **without** the schema:: + + >>> metadata_obj.tables['referred'].schema is None + True + +To alter the behavior of reflection such that the referred schema is +maintained regardless of the ``search_path`` setting, use the +``postgresql_ignore_search_path`` option, which can be specified as a +dialect-specific argument to both :class:`_schema.Table` as well as +:meth:`_schema.MetaData.reflect`:: + + >>> with engine.connect() as conn: + ... conn.execute(text("SET search_path TO test_schema, public")) + ... metadata_obj = MetaData() + ... referring = Table('referring', metadata_obj, + ... autoload_with=conn, + ... postgresql_ignore_search_path=True) + ... + <sqlalchemy.engine.result.CursorResult object at 0x1016126d0> + +We will now have ``test_schema.referred`` stored as schema-qualified:: + + >>> metadata_obj.tables['test_schema.referred'].schema + 'test_schema' + +.. sidebar:: Best Practices for PostgreSQL Schema reflection + + The description of PostgreSQL schema reflection behavior is complex, and + is the product of many years of dealing with widely varied use cases and + user preferences. But in fact, there's no need to understand any of it if + you just stick to the simplest use pattern: leave the ``search_path`` set + to its default of ``public`` only, never refer to the name ``public`` as + an explicit schema name otherwise, and refer to all other schema names + explicitly when building up a :class:`_schema.Table` object. The options + described here are only for those users who can't, or prefer not to, stay + within these guidelines. + +Note that **in all cases**, the "default" schema is always reflected as +``None``. The "default" schema on PostgreSQL is that which is returned by the +PostgreSQL ``current_schema()`` function. On a typical PostgreSQL +installation, this is the name ``public``. So a table that refers to another +which is in the ``public`` (i.e. default) schema will always have the +``.schema`` attribute set to ``None``. + +.. seealso:: + + :ref:`reflection_schema_qualified_interaction` - discussion of the issue + from a backend-agnostic perspective + + `The Schema Search Path + <https://www.postgresql.org/docs/current/static/ddl-schemas.html#DDL-SCHEMAS-PATH>`_ + - on the PostgreSQL website. + +INSERT/UPDATE...RETURNING +------------------------- + +The dialect supports PG 8.2's ``INSERT..RETURNING``, ``UPDATE..RETURNING`` and +``DELETE..RETURNING`` syntaxes. ``INSERT..RETURNING`` is used by default +for single-row INSERT statements in order to fetch newly generated +primary key identifiers. To specify an explicit ``RETURNING`` clause, +use the :meth:`._UpdateBase.returning` method on a per-statement basis:: + + # INSERT..RETURNING + result = table.insert().returning(table.c.col1, table.c.col2).\ + values(name='foo') + print(result.fetchall()) + + # UPDATE..RETURNING + result = table.update().returning(table.c.col1, table.c.col2).\ + where(table.c.name=='foo').values(name='bar') + print(result.fetchall()) + + # DELETE..RETURNING + result = table.delete().returning(table.c.col1, table.c.col2).\ + where(table.c.name=='foo') + print(result.fetchall()) + +.. _postgresql_insert_on_conflict: + +INSERT...ON CONFLICT (Upsert) +------------------------------ + +Starting with version 9.5, PostgreSQL allows "upserts" (update or insert) of +rows into a table via the ``ON CONFLICT`` clause of the ``INSERT`` statement. A +candidate row will only be inserted if that row does not violate any unique +constraints. In the case of a unique constraint violation, a secondary action +can occur which can be either "DO UPDATE", indicating that the data in the +target row should be updated, or "DO NOTHING", which indicates to silently skip +this row. + +Conflicts are determined using existing unique constraints and indexes. These +constraints may be identified either using their name as stated in DDL, +or they may be inferred by stating the columns and conditions that comprise +the indexes. + +SQLAlchemy provides ``ON CONFLICT`` support via the PostgreSQL-specific +:func:`_postgresql.insert()` function, which provides +the generative methods :meth:`_postgresql.Insert.on_conflict_do_update` +and :meth:`~.postgresql.Insert.on_conflict_do_nothing`: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.dialects.postgresql import insert + >>> insert_stmt = insert(my_table).values( + ... id='some_existing_id', + ... data='inserted value') + >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing( + ... index_elements=['id'] + ... ) + >>> print(do_nothing_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO NOTHING + {stop} + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint='pk_my_table', + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT ON CONSTRAINT pk_my_table DO UPDATE SET data = %(param_1)s + +.. versionadded:: 1.1 + +.. seealso:: + + `INSERT .. ON CONFLICT + <https://www.postgresql.org/docs/current/static/sql-insert.html#SQL-ON-CONFLICT>`_ + - in the PostgreSQL documentation. + +Specifying the Target +^^^^^^^^^^^^^^^^^^^^^ + +Both methods supply the "target" of the conflict using either the +named constraint or by column inference: + +* The :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` argument + specifies a sequence containing string column names, :class:`_schema.Column` + objects, and/or SQL expression elements, which would identify a unique + index: + + .. sourcecode:: pycon+sql + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + {stop} + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... index_elements=[my_table.c.id], + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + +* When using :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` to + infer an index, a partial index can be inferred by also specifying the + use the :paramref:`_postgresql.Insert.on_conflict_do_update.index_where` parameter: + + .. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data') + >>> stmt = stmt.on_conflict_do_update( + ... index_elements=[my_table.c.user_email], + ... index_where=my_table.c.user_email.like('%@gmail.com'), + ... set_=dict(data=stmt.excluded.data) + ... ) + >>> print(stmt) + {opensql}INSERT INTO my_table (data, user_email) + VALUES (%(data)s, %(user_email)s) ON CONFLICT (user_email) + WHERE user_email LIKE %(user_email_1)s DO UPDATE SET data = excluded.data + +* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument is + used to specify an index directly rather than inferring it. This can be + the name of a UNIQUE constraint, a PRIMARY KEY constraint, or an INDEX: + + .. sourcecode:: pycon+sql + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint='my_table_idx_1', + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT ON CONSTRAINT my_table_idx_1 DO UPDATE SET data = %(param_1)s + {stop} + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint='my_table_pk', + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT ON CONSTRAINT my_table_pk DO UPDATE SET data = %(param_1)s + {stop} + +* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument may + also refer to a SQLAlchemy construct representing a constraint, + e.g. :class:`.UniqueConstraint`, :class:`.PrimaryKeyConstraint`, + :class:`.Index`, or :class:`.ExcludeConstraint`. In this use, + if the constraint has a name, it is used directly. Otherwise, if the + constraint is unnamed, then inference will be used, where the expressions + and optional WHERE clause of the constraint will be spelled out in the + construct. This use is especially convenient + to refer to the named or unnamed primary key of a :class:`_schema.Table` + using the + :attr:`_schema.Table.primary_key` attribute: + + .. sourcecode:: pycon+sql + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint=my_table.primary_key, + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + +The SET Clause +^^^^^^^^^^^^^^^ + +``ON CONFLICT...DO UPDATE`` is used to perform an update of the already +existing row, using any combination of new values as well as values +from the proposed insertion. These values are specified using the +:paramref:`_postgresql.Insert.on_conflict_do_update.set_` parameter. This +parameter accepts a dictionary which consists of direct values +for UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + +.. warning:: + + The :meth:`_expression.Insert.on_conflict_do_update` + method does **not** take into + account Python-side default UPDATE values or generation functions, e.g. + those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON CONFLICT style of UPDATE, + unless they are manually specified in the + :paramref:`_postgresql.Insert.on_conflict_do_update.set_` dictionary. + +Updating using the Excluded INSERT Values +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to refer to the proposed insertion row, the special alias +:attr:`~.postgresql.Insert.excluded` is available as an attribute on +the :class:`_postgresql.Insert` object; this object is a +:class:`_expression.ColumnCollection` +which alias contains all columns of the target +table: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author) + ... ) + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (id, data, author) + VALUES (%(id)s, %(data)s, %(author)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author + +Additional WHERE Criteria +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :meth:`_expression.Insert.on_conflict_do_update` method also accepts +a WHERE clause using the :paramref:`_postgresql.Insert.on_conflict_do_update.where` +parameter, which will limit those rows which receive an UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + >>> on_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author), + ... where=(my_table.c.status == 2) + ... ) + >>> print(on_update_stmt) + {opensql}INSERT INTO my_table (id, data, author) + VALUES (%(id)s, %(data)s, %(author)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author + WHERE my_table.status = %(status_1)s + +Skipping Rows with DO NOTHING +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``ON CONFLICT`` may be used to skip inserting a row entirely +if any conflict with a unique or exclusion constraint occurs; below +this is illustrated using the +:meth:`~.postgresql.Insert.on_conflict_do_nothing` method: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id']) + >>> print(stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO NOTHING + +If ``DO NOTHING`` is used without specifying any columns or constraint, +it has the effect of skipping the INSERT for any unique or exclusion +constraint violation which occurs: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing() + >>> print(stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT DO NOTHING + +.. _postgresql_match: + +Full Text Search +---------------- + +SQLAlchemy makes available the PostgreSQL ``@@`` operator via the +:meth:`_expression.ColumnElement.match` method on any textual column expression. + +On the PostgreSQL dialect, an expression like the following:: + + select(sometable.c.text.match("search string")) + +will emit to the database:: + + SELECT text @@ to_tsquery('search string') FROM table + +Various other PostgreSQL text search functions such as ``to_tsquery()``, +``to_tsvector()``, and ``plainto_tsquery()`` are available by explicitly using +the standard SQLAlchemy :data:`.func` construct. + +For example:: + + select(func.to_tsvector('fat cats ate rats').match('cat & rat')) + +Emits the equivalent of:: + + SELECT to_tsvector('fat cats ate rats') @@ to_tsquery('cat & rat') + +The :class:`_postgresql.TSVECTOR` type can provide for explicit CAST:: + + from sqlalchemy.dialects.postgresql import TSVECTOR + from sqlalchemy import select, cast + select(cast("some text", TSVECTOR)) + +produces a statement equivalent to:: + + SELECT CAST('some text' AS TSVECTOR) AS anon_1 + +.. tip:: + + It's important to remember that text searching in PostgreSQL is powerful but complicated, + and SQLAlchemy users are advised to reference the PostgreSQL documentation + regarding + `Full Text Search <https://www.postgresql.org/docs/current/textsearch-controls.html>`_. + + There are important differences between ``to_tsquery`` and + ``plainto_tsquery``, the most significant of which is that ``to_tsquery`` + expects specially formatted "querytext" that is written to PostgreSQL's own + specification, while ``plainto_tsquery`` expects unformatted text that is + transformed into ``to_tsquery`` compatible querytext. This means the input to + ``.match()`` under PostgreSQL may be incompatible with the input to + ``.match()`` under another database backend. SQLAlchemy users who support + multiple backends are advised to carefully implement their usage of + ``.match()`` to work around these constraints. + +Full Text Searches in PostgreSQL are influenced by a combination of: the +PostgreSQL setting of ``default_text_search_config``, the ``regconfig`` used +to build the GIN/GiST indexes, and the ``regconfig`` optionally passed in +during a query. + +When performing a Full Text Search against a column that has a GIN or +GiST index that is already pre-computed (which is common on full text +searches) one may need to explicitly pass in a particular PostgreSQL +``regconfig`` value to ensure the query-planner utilizes the index and does +not re-compute the column on demand. + +In order to provide for this explicit query planning, or to use different +search strategies, the ``match`` method accepts a ``postgresql_regconfig`` +keyword argument:: + + select(mytable.c.id).where( + mytable.c.title.match('somestring', postgresql_regconfig='english') + ) + +Emits the equivalent of:: + + SELECT mytable.id FROM mytable + WHERE mytable.title @@ to_tsquery('english', 'somestring') + +One can also specifically pass in a `'regconfig'` value to the +``to_tsvector()`` command as the initial argument:: + + select(mytable.c.id).where( + func.to_tsvector('english', mytable.c.title )\ + .match('somestring', postgresql_regconfig='english') + ) + +produces a statement equivalent to:: + + SELECT mytable.id FROM mytable + WHERE to_tsvector('english', mytable.title) @@ + to_tsquery('english', 'somestring') + +It is recommended that you use the ``EXPLAIN ANALYZE...`` tool from +PostgreSQL to ensure that you are generating queries with SQLAlchemy that +take full advantage of any indexes you may have created for full text search. + +.. seealso:: + + `Full Text Search <https://www.postgresql.org/docs/current/textsearch-controls.html>`_ - in the PostgreSQL documentation + + +FROM ONLY ... +------------- + +The dialect supports PostgreSQL's ONLY keyword for targeting only a particular +table in an inheritance hierarchy. This can be used to produce the +``SELECT ... FROM ONLY``, ``UPDATE ONLY ...``, and ``DELETE FROM ONLY ...`` +syntaxes. It uses SQLAlchemy's hints mechanism:: + + # SELECT ... FROM ONLY ... + result = table.select().with_hint(table, 'ONLY', 'postgresql') + print(result.fetchall()) + + # UPDATE ONLY ... + table.update(values=dict(foo='bar')).with_hint('ONLY', + dialect_name='postgresql') + + # DELETE FROM ONLY ... + table.delete().with_hint('ONLY', dialect_name='postgresql') + + +.. _postgresql_indexes: + +PostgreSQL-Specific Index Options +--------------------------------- + +Several extensions to the :class:`.Index` construct are available, specific +to the PostgreSQL dialect. + +Covering Indexes +^^^^^^^^^^^^^^^^ + +The ``postgresql_include`` option renders INCLUDE(colname) for the given +string names:: + + Index("my_index", table.c.x, postgresql_include=['y']) + +would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)`` + +Note that this feature requires PostgreSQL 11 or later. + +.. versionadded:: 1.4 + +.. _postgresql_partial_indexes: + +Partial Indexes +^^^^^^^^^^^^^^^ + +Partial indexes add criterion to the index definition so that the index is +applied to a subset of rows. These can be specified on :class:`.Index` +using the ``postgresql_where`` keyword argument:: + + Index('my_index', my_table.c.id, postgresql_where=my_table.c.value > 10) + +.. _postgresql_operator_classes: + +Operator Classes +^^^^^^^^^^^^^^^^ + +PostgreSQL allows the specification of an *operator class* for each column of +an index (see +https://www.postgresql.org/docs/current/interactive/indexes-opclass.html). +The :class:`.Index` construct allows these to be specified via the +``postgresql_ops`` keyword argument:: + + Index( + 'my_index', my_table.c.id, my_table.c.data, + postgresql_ops={ + 'data': 'text_pattern_ops', + 'id': 'int4_ops' + }) + +Note that the keys in the ``postgresql_ops`` dictionaries are the +"key" name of the :class:`_schema.Column`, i.e. the name used to access it from +the ``.c`` collection of :class:`_schema.Table`, which can be configured to be +different than the actual name of the column as expressed in the database. + +If ``postgresql_ops`` is to be used against a complex SQL expression such +as a function call, then to apply to the column it must be given a label +that is identified in the dictionary by name, e.g.:: + + Index( + 'my_index', my_table.c.id, + func.lower(my_table.c.data).label('data_lower'), + postgresql_ops={ + 'data_lower': 'text_pattern_ops', + 'id': 'int4_ops' + }) + +Operator classes are also supported by the +:class:`_postgresql.ExcludeConstraint` construct using the +:paramref:`_postgresql.ExcludeConstraint.ops` parameter. See that parameter for +details. + +.. versionadded:: 1.3.21 added support for operator classes with + :class:`_postgresql.ExcludeConstraint`. + + +Index Types +^^^^^^^^^^^ + +PostgreSQL provides several index types: B-Tree, Hash, GiST, and GIN, as well +as the ability for users to create their own (see +https://www.postgresql.org/docs/current/static/indexes-types.html). These can be +specified on :class:`.Index` using the ``postgresql_using`` keyword argument:: + + Index('my_index', my_table.c.data, postgresql_using='gin') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX command, so it *must* be a valid index type for your +version of PostgreSQL. + +.. _postgresql_index_storage: + +Index Storage Parameters +^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL allows storage parameters to be set on indexes. The storage +parameters available depend on the index method used by the index. Storage +parameters can be specified on :class:`.Index` using the ``postgresql_with`` +keyword argument:: + + Index('my_index', my_table.c.data, postgresql_with={"fillfactor": 50}) + +.. versionadded:: 1.0.6 + +PostgreSQL allows to define the tablespace in which to create the index. +The tablespace can be specified on :class:`.Index` using the +``postgresql_tablespace`` keyword argument:: + + Index('my_index', my_table.c.data, postgresql_tablespace='my_tablespace') + +.. versionadded:: 1.1 + +Note that the same option is available on :class:`_schema.Table` as well. + +.. _postgresql_index_concurrently: + +Indexes with CONCURRENTLY +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The PostgreSQL index option CONCURRENTLY is supported by passing the +flag ``postgresql_concurrently`` to the :class:`.Index` construct:: + + tbl = Table('testtbl', m, Column('data', Integer)) + + idx1 = Index('test_idx1', tbl.c.data, postgresql_concurrently=True) + +The above index construct will render DDL for CREATE INDEX, assuming +PostgreSQL 8.2 or higher is detected or for a connection-less dialect, as:: + + CREATE INDEX CONCURRENTLY test_idx1 ON testtbl (data) + +For DROP INDEX, assuming PostgreSQL 9.2 or higher is detected or for +a connection-less dialect, it will emit:: + + DROP INDEX CONCURRENTLY test_idx1 + +.. versionadded:: 1.1 support for CONCURRENTLY on DROP INDEX. The + CONCURRENTLY keyword is now only emitted if a high enough version + of PostgreSQL is detected on the connection (or for a connection-less + dialect). + +When using CONCURRENTLY, the PostgreSQL database requires that the statement +be invoked outside of a transaction block. The Python DBAPI enforces that +even for a single statement, a transaction is present, so to use this +construct, the DBAPI's "autocommit" mode must be used:: + + metadata = MetaData() + table = Table( + "foo", metadata, + Column("id", String)) + index = Index( + "foo_idx", table.c.id, postgresql_concurrently=True) + + with engine.connect() as conn: + with conn.execution_options(isolation_level='AUTOCOMMIT'): + table.create(conn) + +.. seealso:: + + :ref:`postgresql_isolation_level` + +.. _postgresql_index_reflection: + +PostgreSQL Index Reflection +--------------------------- + +The PostgreSQL database creates a UNIQUE INDEX implicitly whenever the +UNIQUE CONSTRAINT construct is used. When inspecting a table using +:class:`_reflection.Inspector`, the :meth:`_reflection.Inspector.get_indexes` +and the :meth:`_reflection.Inspector.get_unique_constraints` +will report on these +two constructs distinctly; in the case of the index, the key +``duplicates_constraint`` will be present in the index entry if it is +detected as mirroring a constraint. When performing reflection using +``Table(..., autoload_with=engine)``, the UNIQUE INDEX is **not** returned +in :attr:`_schema.Table.indexes` when it is detected as mirroring a +:class:`.UniqueConstraint` in the :attr:`_schema.Table.constraints` collection +. + +.. versionchanged:: 1.0.0 - :class:`_schema.Table` reflection now includes + :class:`.UniqueConstraint` objects present in the + :attr:`_schema.Table.constraints` + collection; the PostgreSQL backend will no longer include a "mirrored" + :class:`.Index` construct in :attr:`_schema.Table.indexes` + if it is detected + as corresponding to a unique constraint. + +Special Reflection Options +-------------------------- + +The :class:`_reflection.Inspector` +used for the PostgreSQL backend is an instance +of :class:`.PGInspector`, which offers additional methods:: + + from sqlalchemy import create_engine, inspect + + engine = create_engine("postgresql+psycopg2://localhost/test") + insp = inspect(engine) # will be a PGInspector + + print(insp.get_enums()) + +.. autoclass:: PGInspector + :members: + +.. _postgresql_table_options: + +PostgreSQL Table Options +------------------------ + +Several options for CREATE TABLE are supported directly by the PostgreSQL +dialect in conjunction with the :class:`_schema.Table` construct: + +* ``TABLESPACE``:: + + Table("some_table", metadata, ..., postgresql_tablespace='some_tablespace') + + The above option is also available on the :class:`.Index` construct. + +* ``ON COMMIT``:: + + Table("some_table", metadata, ..., postgresql_on_commit='PRESERVE ROWS') + +* ``WITH OIDS``:: + + Table("some_table", metadata, ..., postgresql_with_oids=True) + +* ``WITHOUT OIDS``:: + + Table("some_table", metadata, ..., postgresql_with_oids=False) + +* ``INHERITS``:: + + Table("some_table", metadata, ..., postgresql_inherits="some_supertable") + + Table("some_table", metadata, ..., postgresql_inherits=("t1", "t2", ...)) + + .. versionadded:: 1.0.0 + +* ``PARTITION BY``:: + + Table("some_table", metadata, ..., + postgresql_partition_by='LIST (part_column)') + + .. versionadded:: 1.2.6 + +.. seealso:: + + `PostgreSQL CREATE TABLE options + <https://www.postgresql.org/docs/current/static/sql-createtable.html>`_ - + in the PostgreSQL documentation. + +.. _postgresql_constraint_options: + +PostgreSQL Constraint Options +----------------------------- + +The following option(s) are supported by the PostgreSQL dialect in conjunction +with selected constraint constructs: + +* ``NOT VALID``: This option applies towards CHECK and FOREIGN KEY constraints + when the constraint is being added to an existing table via ALTER TABLE, + and has the effect that existing rows are not scanned during the ALTER + operation against the constraint being added. + + When using a SQL migration tool such as `Alembic <https://alembic.sqlalchemy.org>`_ + that renders ALTER TABLE constructs, the ``postgresql_not_valid`` argument + may be specified as an additional keyword argument within the operation + that creates the constraint, as in the following Alembic example:: + + def update(): + op.create_foreign_key( + "fk_user_address", + "address", + "user", + ["user_id"], + ["id"], + postgresql_not_valid=True + ) + + The keyword is ultimately accepted directly by the + :class:`_schema.CheckConstraint`, :class:`_schema.ForeignKeyConstraint` + and :class:`_schema.ForeignKey` constructs; when using a tool like + Alembic, dialect-specific keyword arguments are passed through to + these constructs from the migration operation directives:: + + CheckConstraint("some_field IS NOT NULL", postgresql_not_valid=True) + + ForeignKeyConstraint(["some_id"], ["some_table.some_id"], postgresql_not_valid=True) + + .. versionadded:: 1.4.32 + + .. seealso:: + + `PostgreSQL ALTER TABLE options + <https://www.postgresql.org/docs/current/static/sql-altertable.html>`_ - + in the PostgreSQL documentation. + +.. _postgresql_table_valued_overview: + +Table values, Table and Column valued functions, Row and Tuple objects +----------------------------------------------------------------------- + +PostgreSQL makes great use of modern SQL forms such as table-valued functions, +tables and rows as values. These constructs are commonly used as part +of PostgreSQL's support for complex datatypes such as JSON, ARRAY, and other +datatypes. SQLAlchemy's SQL expression language has native support for +most table-valued and row-valued forms. + +.. _postgresql_table_valued: + +Table-Valued Functions +^^^^^^^^^^^^^^^^^^^^^^^ + +Many PostgreSQL built-in functions are intended to be used in the FROM clause +of a SELECT statement, and are capable of returning table rows or sets of table +rows. A large portion of PostgreSQL's JSON functions for example such as +``json_array_elements()``, ``json_object_keys()``, ``json_each_text()``, +``json_each()``, ``json_to_record()``, ``json_populate_recordset()`` use such +forms. These classes of SQL function calling forms in SQLAlchemy are available +using the :meth:`_functions.FunctionElement.table_valued` method in conjunction +with :class:`_functions.Function` objects generated from the :data:`_sql.func` +namespace. + +Examples from PostgreSQL's reference documentation follow below: + +* ``json_each()``:: + + >>> from sqlalchemy import select, func + >>> stmt = select(func.json_each('{"a":"foo", "b":"bar"}').table_valued("key", "value")) + >>> print(stmt) + SELECT anon_1.key, anon_1.value + FROM json_each(:json_each_1) AS anon_1 + +* ``json_populate_record()``:: + + >>> from sqlalchemy import select, func, literal_column + >>> stmt = select( + ... func.json_populate_record( + ... literal_column("null::myrowtype"), + ... '{"a":1,"b":2}' + ... ).table_valued("a", "b", name="x") + ... ) + >>> print(stmt) + SELECT x.a, x.b + FROM json_populate_record(null::myrowtype, :json_populate_record_1) AS x + +* ``json_to_record()`` - this form uses a PostgreSQL specific form of derived + columns in the alias, where we may make use of :func:`_sql.column` elements with + types to produce them. The :meth:`_functions.FunctionElement.table_valued` + method produces a :class:`_sql.TableValuedAlias` construct, and the method + :meth:`_sql.TableValuedAlias.render_derived` method sets up the derived + columns specification:: + + >>> from sqlalchemy import select, func, column, Integer, Text + >>> stmt = select( + ... func.json_to_record('{"a":1,"b":[1,2,3],"c":"bar"}').table_valued( + ... column("a", Integer), column("b", Text), column("d", Text), + ... ).render_derived(name="x", with_types=True) + ... ) + >>> print(stmt) + SELECT x.a, x.b, x.d + FROM json_to_record(:json_to_record_1) AS x(a INTEGER, b TEXT, d TEXT) + +* ``WITH ORDINALITY`` - part of the SQL standard, ``WITH ORDINALITY`` adds an + ordinal counter to the output of a function and is accepted by a limited set + of PostgreSQL functions including ``unnest()`` and ``generate_series()``. The + :meth:`_functions.FunctionElement.table_valued` method accepts a keyword + parameter ``with_ordinality`` for this purpose, which accepts the string name + that will be applied to the "ordinality" column:: + + >>> from sqlalchemy import select, func + >>> stmt = select( + ... func.generate_series(4, 1, -1). + ... table_valued("value", with_ordinality="ordinality"). + ... render_derived() + ... ) + >>> print(stmt) + SELECT anon_1.value, anon_1.ordinality + FROM generate_series(:generate_series_1, :generate_series_2, :generate_series_3) + WITH ORDINALITY AS anon_1(value, ordinality) + +.. versionadded:: 1.4.0b2 + +.. seealso:: + + :ref:`tutorial_functions_table_valued` - in the :ref:`unified_tutorial` + +.. _postgresql_column_valued: + +Column Valued Functions +^^^^^^^^^^^^^^^^^^^^^^^ + +Similar to the table valued function, a column valued function is present +in the FROM clause, but delivers itself to the columns clause as a single +scalar value. PostgreSQL functions such as ``json_array_elements()``, +``unnest()`` and ``generate_series()`` may use this form. Column valued functions are available using the +:meth:`_functions.FunctionElement.column_valued` method of :class:`_functions.FunctionElement`: + +* ``json_array_elements()``:: + + >>> from sqlalchemy import select, func + >>> stmt = select(func.json_array_elements('["one", "two"]').column_valued("x")) + >>> print(stmt) + SELECT x + FROM json_array_elements(:json_array_elements_1) AS x + +* ``unnest()`` - in order to generate a PostgreSQL ARRAY literal, the + :func:`_postgresql.array` construct may be used:: + + + >>> from sqlalchemy.dialects.postgresql import array + >>> from sqlalchemy import select, func + >>> stmt = select(func.unnest(array([1, 2])).column_valued()) + >>> print(stmt) + SELECT anon_1 + FROM unnest(ARRAY[%(param_1)s, %(param_2)s]) AS anon_1 + + The function can of course be used against an existing table-bound column + that's of type :class:`_types.ARRAY`:: + + >>> from sqlalchemy import table, column, ARRAY, Integer + >>> from sqlalchemy import select, func + >>> t = table("t", column('value', ARRAY(Integer))) + >>> stmt = select(func.unnest(t.c.value).column_valued("unnested_value")) + >>> print(stmt) + SELECT unnested_value + FROM unnest(t.value) AS unnested_value + +.. seealso:: + + :ref:`tutorial_functions_column_valued` - in the :ref:`unified_tutorial` + + +Row Types +^^^^^^^^^ + +Built-in support for rendering a ``ROW`` may be approximated using +``func.ROW`` with the :attr:`_sa.func` namespace, or by using the +:func:`_sql.tuple_` construct:: + + >>> from sqlalchemy import table, column, func, tuple_ + >>> t = table("t", column("id"), column("fk")) + >>> stmt = t.select().where( + ... tuple_(t.c.id, t.c.fk) > (1,2) + ... ).where( + ... func.ROW(t.c.id, t.c.fk) < func.ROW(3, 7) + ... ) + >>> print(stmt) + SELECT t.id, t.fk + FROM t + WHERE (t.id, t.fk) > (:param_1, :param_2) AND ROW(t.id, t.fk) < ROW(:ROW_1, :ROW_2) + +.. seealso:: + + `PostgreSQL Row Constructors + <https://www.postgresql.org/docs/current/sql-expressions.html#SQL-SYNTAX-ROW-CONSTRUCTORS>`_ + + `PostgreSQL Row Constructor Comparison + <https://www.postgresql.org/docs/current/functions-comparisons.html#ROW-WISE-COMPARISON>`_ + +Table Types passed to Functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL supports passing a table as an argument to a function, which it +refers towards as a "record" type. SQLAlchemy :class:`_sql.FromClause` objects +such as :class:`_schema.Table` support this special form using the +:meth:`_sql.FromClause.table_valued` method, which is comparable to the +:meth:`_functions.FunctionElement.table_valued` method except that the collection +of columns is already established by that of the :class:`_sql.FromClause` +itself:: + + + >>> from sqlalchemy import table, column, func, select + >>> a = table( "a", column("id"), column("x"), column("y")) + >>> stmt = select(func.row_to_json(a.table_valued())) + >>> print(stmt) + SELECT row_to_json(a) AS row_to_json_1 + FROM a + +.. versionadded:: 1.4.0b2 + + +ARRAY Types +----------- + +The PostgreSQL dialect supports arrays, both as multidimensional column types +as well as array literals: + +* :class:`_postgresql.ARRAY` - ARRAY datatype + +* :class:`_postgresql.array` - array literal + +* :func:`_postgresql.array_agg` - ARRAY_AGG SQL function + +* :class:`_postgresql.aggregate_order_by` - helper for PG's ORDER BY aggregate + function syntax. + +JSON Types +---------- + +The PostgreSQL dialect supports both JSON and JSONB datatypes, including +psycopg2's native support and support for all of PostgreSQL's special +operators: + +* :class:`_postgresql.JSON` + +* :class:`_postgresql.JSONB` + +HSTORE Type +----------- + +The PostgreSQL HSTORE type as well as hstore literals are supported: + +* :class:`_postgresql.HSTORE` - HSTORE datatype + +* :class:`_postgresql.hstore` - hstore literal + +ENUM Types +---------- + +PostgreSQL has an independently creatable TYPE structure which is used +to implement an enumerated type. This approach introduces significant +complexity on the SQLAlchemy side in terms of when this type should be +CREATED and DROPPED. The type object is also an independently reflectable +entity. The following sections should be consulted: + +* :class:`_postgresql.ENUM` - DDL and typing support for ENUM. + +* :meth:`.PGInspector.get_enums` - retrieve a listing of current ENUM types + +* :meth:`.postgresql.ENUM.create` , :meth:`.postgresql.ENUM.drop` - individual + CREATE and DROP commands for ENUM. + +.. _postgresql_array_of_enum: + +Using ENUM with ARRAY +^^^^^^^^^^^^^^^^^^^^^ + +The combination of ENUM and ARRAY is not directly supported by backend +DBAPIs at this time. Prior to SQLAlchemy 1.3.17, a special workaround +was needed in order to allow this combination to work, described below. + +.. versionchanged:: 1.3.17 The combination of ENUM and ARRAY is now directly + handled by SQLAlchemy's implementation without any workarounds needed. + +.. sourcecode:: python + + from sqlalchemy import TypeDecorator + from sqlalchemy.dialects.postgresql import ARRAY + + class ArrayOfEnum(TypeDecorator): + impl = ARRAY + + def bind_expression(self, bindvalue): + return sa.cast(bindvalue, self) + + def result_processor(self, dialect, coltype): + super_rp = super(ArrayOfEnum, self).result_processor( + dialect, coltype) + + def handle_raw_string(value): + inner = re.match(r"^{(.*)}$", value).group(1) + return inner.split(",") if inner else [] + + def process(value): + if value is None: + return None + return super_rp(handle_raw_string(value)) + return process + +E.g.:: + + Table( + 'mydata', metadata, + Column('id', Integer, primary_key=True), + Column('data', ArrayOfEnum(ENUM('a', 'b, 'c', name='myenum'))) + + ) + +This type is not included as a built-in type as it would be incompatible +with a DBAPI that suddenly decides to support ARRAY of ENUM directly in +a new version. + +.. _postgresql_array_of_json: + +Using JSON/JSONB with ARRAY +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Similar to using ENUM, prior to SQLAlchemy 1.3.17, for an ARRAY of JSON/JSONB +we need to render the appropriate CAST. Current psycopg2 drivers accommodate +the result set correctly without any special steps. + +.. versionchanged:: 1.3.17 The combination of JSON/JSONB and ARRAY is now + directly handled by SQLAlchemy's implementation without any workarounds + needed. + +.. sourcecode:: python + + class CastingArray(ARRAY): + def bind_expression(self, bindvalue): + return sa.cast(bindvalue, self) + +E.g.:: + + Table( + 'mydata', metadata, + Column('id', Integer, primary_key=True), + Column('data', CastingArray(JSONB)) + ) + + +""" # noqa: E501 + +from collections import defaultdict +import datetime as dt +import re +from uuid import UUID as _python_UUID + +from . import array as _array +from . import dml +from . import hstore as _hstore +from . import json as _json +from . import ranges as _ranges +from ... import exc +from ... import schema +from ... import sql +from ... import util +from ...engine import characteristics +from ...engine import default +from ...engine import reflection +from ...sql import coercions +from ...sql import compiler +from ...sql import elements +from ...sql import expression +from ...sql import roles +from ...sql import sqltypes +from ...sql import util as sql_util +from ...sql.ddl import DDLBase +from ...types import BIGINT +from ...types import BOOLEAN +from ...types import CHAR +from ...types import DATE +from ...types import FLOAT +from ...types import INTEGER +from ...types import NUMERIC +from ...types import REAL +from ...types import SMALLINT +from ...types import TEXT +from ...types import VARCHAR + +IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I) + +AUTOCOMMIT_REGEXP = re.compile( + r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|GRANT|REVOKE|" + "IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW|TRUNCATE)", + re.I | re.UNICODE, +) + +RESERVED_WORDS = set( + [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "leading", + "limit", + "localtime", + "localtimestamp", + "new", + "not", + "null", + "of", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", + "authorization", + "between", + "binary", + "cross", + "current_schema", + "freeze", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "notnull", + "outer", + "over", + "overlaps", + "right", + "similar", + "verbose", + ] +) + +_DECIMAL_TYPES = (1231, 1700) +_FLOAT_TYPES = (700, 701, 1021, 1022) +_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) + + +class BYTEA(sqltypes.LargeBinary): + __visit_name__ = "BYTEA" + + +class DOUBLE_PRECISION(sqltypes.Float): + __visit_name__ = "DOUBLE_PRECISION" + + +class INET(sqltypes.TypeEngine): + __visit_name__ = "INET" + + +PGInet = INET + + +class CIDR(sqltypes.TypeEngine): + __visit_name__ = "CIDR" + + +PGCidr = CIDR + + +class MACADDR(sqltypes.TypeEngine): + __visit_name__ = "MACADDR" + + +PGMacAddr = MACADDR + + +class MONEY(sqltypes.TypeEngine): + + r"""Provide the PostgreSQL MONEY type. + + Depending on driver, result rows using this type may return a + string value which includes currency symbols. + + For this reason, it may be preferable to provide conversion to a + numerically-based currency datatype using :class:`_types.TypeDecorator`:: + + import re + import decimal + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def process_result_value(self, value: Any, dialect: Any) -> None: + if value is not None: + # adjust this for the currency and numeric + m = re.match(r"\$([\d.]+)", value) + if m: + value = decimal.Decimal(m.group(1)) + return value + + Alternatively, the conversion may be applied as a CAST using + the :meth:`_types.TypeDecorator.column_expression` method as follows:: + + import decimal + from sqlalchemy import cast + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def column_expression(self, column: Any): + return cast(column, Numeric()) + + .. versionadded:: 1.2 + + """ + + __visit_name__ = "MONEY" + + +class OID(sqltypes.TypeEngine): + + """Provide the PostgreSQL OID type. + + .. versionadded:: 0.9.5 + + """ + + __visit_name__ = "OID" + + +class REGCLASS(sqltypes.TypeEngine): + + """Provide the PostgreSQL REGCLASS type. + + .. versionadded:: 1.2.7 + + """ + + __visit_name__ = "REGCLASS" + + +class TIMESTAMP(sqltypes.TIMESTAMP): + + """Provide the PostgreSQL TIMESTAMP type.""" + + __visit_name__ = "TIMESTAMP" + + def __init__(self, timezone=False, precision=None): + """Construct a TIMESTAMP. + + :param timezone: boolean value if timezone present, default False + :param precision: optional integer precision value + + .. versionadded:: 1.4 + + """ + super(TIMESTAMP, self).__init__(timezone=timezone) + self.precision = precision + + +class TIME(sqltypes.TIME): + + """PostgreSQL TIME type.""" + + __visit_name__ = "TIME" + + def __init__(self, timezone=False, precision=None): + """Construct a TIME. + + :param timezone: boolean value if timezone present, default False + :param precision: optional integer precision value + + .. versionadded:: 1.4 + + """ + super(TIME, self).__init__(timezone=timezone) + self.precision = precision + + +class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): + + """PostgreSQL INTERVAL type.""" + + __visit_name__ = "INTERVAL" + native = True + + def __init__(self, precision=None, fields=None): + """Construct an INTERVAL. + + :param precision: optional integer precision value + :param fields: string fields specifier. allows storage of fields + to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``, + etc. + + .. versionadded:: 1.2 + + """ + self.precision = precision + self.fields = fields + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + return INTERVAL(precision=interval.second_precision) + + @property + def _type_affinity(self): + return sqltypes.Interval + + def as_generic(self, allow_nulltype=False): + return sqltypes.Interval(native=True, second_precision=self.precision) + + @property + def python_type(self): + return dt.timedelta + + def coerce_compared_value(self, op, value): + return self + + +PGInterval = INTERVAL + + +class BIT(sqltypes.TypeEngine): + __visit_name__ = "BIT" + + def __init__(self, length=None, varying=False): + if not varying: + # BIT without VARYING defaults to length 1 + self.length = length or 1 + else: + # but BIT VARYING can be unlimited-length, so no default + self.length = length + self.varying = varying + + +PGBit = BIT + + +class UUID(sqltypes.TypeEngine): + + """PostgreSQL UUID type. + + Represents the UUID column type, interpreting + data either as natively returned by the DBAPI + or as Python uuid objects. + + The UUID type is currently known to work within the prominent DBAPI + drivers supported by SQLAlchemy including psycopg2, pg8000 and + asyncpg. Support for other DBAPI drivers may be incomplete or non-present. + + """ + + __visit_name__ = "UUID" + + def __init__(self, as_uuid=False): + """Construct a UUID type. + + + :param as_uuid=False: if True, values will be interpreted + as Python uuid objects, converting to/from string via the + DBAPI. + + """ + self.as_uuid = as_uuid + + def coerce_compared_value(self, op, value): + """See :meth:`.TypeEngine.coerce_compared_value` for a description.""" + + if isinstance(value, util.string_types): + return self + else: + return super(UUID, self).coerce_compared_value(op, value) + + def bind_processor(self, dialect): + if self.as_uuid: + + def process(value): + if value is not None: + value = util.text_type(value) + return value + + return process + else: + return None + + def result_processor(self, dialect, coltype): + if self.as_uuid: + + def process(value): + if value is not None: + value = _python_UUID(value) + return value + + return process + else: + return None + + def literal_processor(self, dialect): + if self.as_uuid: + + def process(value): + if value is not None: + value = "'%s'::UUID" % value + return value + + return process + else: + + def process(value): + if value is not None: + value = "'%s'" % value + return value + + return process + + @property + def python_type(self): + return _python_UUID if self.as_uuid else str + + +PGUuid = UUID + + +class TSVECTOR(sqltypes.TypeEngine): + + """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL + text search type TSVECTOR. + + It can be used to do full text queries on natural language + documents. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`postgresql_match` + + """ + + __visit_name__ = "TSVECTOR" + + +class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): + + """PostgreSQL ENUM type. + + This is a subclass of :class:`_types.Enum` which includes + support for PG's ``CREATE TYPE`` and ``DROP TYPE``. + + When the builtin type :class:`_types.Enum` is used and the + :paramref:`.Enum.native_enum` flag is left at its default of + True, the PostgreSQL backend will use a :class:`_postgresql.ENUM` + type as the implementation, so the special create/drop rules + will be used. + + The create/drop behavior of ENUM is necessarily intricate, due to the + awkward relationship the ENUM type has in relationship to the + parent table, in that it may be "owned" by just a single table, or + may be shared among many tables. + + When using :class:`_types.Enum` or :class:`_postgresql.ENUM` + in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted + corresponding to when the :meth:`_schema.Table.create` and + :meth:`_schema.Table.drop` + methods are called:: + + table = Table('sometable', metadata, + Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) + ) + + table.create(engine) # will emit CREATE ENUM and CREATE TABLE + table.drop(engine) # will emit DROP TABLE and DROP ENUM + + To use a common enumerated type between multiple tables, the best + practice is to declare the :class:`_types.Enum` or + :class:`_postgresql.ENUM` independently, and associate it with the + :class:`_schema.MetaData` object itself:: + + my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) + + t1 = Table('sometable_one', metadata, + Column('some_enum', myenum) + ) + + t2 = Table('sometable_two', metadata, + Column('some_enum', myenum) + ) + + When this pattern is used, care must still be taken at the level + of individual table creates. Emitting CREATE TABLE without also + specifying ``checkfirst=True`` will still cause issues:: + + t1.create(engine) # will fail: no such type 'myenum' + + If we specify ``checkfirst=True``, the individual table-level create + operation will check for the ``ENUM`` and create if not exists:: + + # will check if enum exists, and emit CREATE TYPE if not + t1.create(engine, checkfirst=True) + + When using a metadata-level ENUM type, the type will always be created + and dropped if either the metadata-wide create/drop is called:: + + metadata.create_all(engine) # will emit CREATE TYPE + metadata.drop_all(engine) # will emit DROP TYPE + + The type can also be created and dropped directly:: + + my_enum.create(engine) + my_enum.drop(engine) + + .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type + now behaves more strictly with regards to CREATE/DROP. A metadata-level + ENUM type will only be created and dropped at the metadata level, + not the table level, with the exception of + ``table.create(checkfirst=True)``. + The ``table.drop()`` call will now emit a DROP TYPE for a table-level + enumerated type. + + """ + + native_enum = True + + def __init__(self, *enums, **kw): + """Construct an :class:`_postgresql.ENUM`. + + Arguments are the same as that of + :class:`_types.Enum`, but also including + the following parameters. + + :param create_type: Defaults to True. + Indicates that ``CREATE TYPE`` should be + emitted, after optionally checking for the + presence of the type, when the parent + table is being created; and additionally + that ``DROP TYPE`` is called when the table + is dropped. When ``False``, no check + will be performed and no ``CREATE TYPE`` + or ``DROP TYPE`` is emitted, unless + :meth:`~.postgresql.ENUM.create` + or :meth:`~.postgresql.ENUM.drop` + are called directly. + Setting to ``False`` is helpful + when invoking a creation scheme to a SQL file + without access to the actual database - + the :meth:`~.postgresql.ENUM.create` and + :meth:`~.postgresql.ENUM.drop` methods can + be used to emit SQL to a target bind. + + """ + native_enum = kw.pop("native_enum", None) + if native_enum is False: + util.warn( + "the native_enum flag does not apply to the " + "sqlalchemy.dialects.postgresql.ENUM datatype; this type " + "always refers to ENUM. Use sqlalchemy.types.Enum for " + "non-native enum." + ) + self.create_type = kw.pop("create_type", True) + super(ENUM, self).__init__(*enums, **kw) + + @classmethod + def adapt_emulated_to_native(cls, impl, **kw): + """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain + :class:`.Enum`. + + """ + kw.setdefault("validate_strings", impl.validate_strings) + kw.setdefault("name", impl.name) + kw.setdefault("schema", impl.schema) + kw.setdefault("inherit_schema", impl.inherit_schema) + kw.setdefault("metadata", impl.metadata) + kw.setdefault("_create_events", False) + kw.setdefault("values_callable", impl.values_callable) + kw.setdefault("omit_aliases", impl._omit_aliases) + return cls(**kw) + + def create(self, bind=None, checkfirst=True): + """Emit ``CREATE TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL CREATE TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type does not exist already before + creating. + + """ + if not bind.dialect.supports_native_enum: + return + + bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=True): + """Emit ``DROP TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL DROP TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type actually exists before dropping. + + """ + if not bind.dialect.supports_native_enum: + return + + bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) + + class EnumGenerator(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_create_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return not self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_create_enum(enum): + return + + self.connection.execute(CreateEnumType(enum)) + + class EnumDropper(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_drop_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_drop_enum(enum): + return + + self.connection.execute(DropEnumType(enum)) + + def _check_for_name_in_memos(self, checkfirst, kw): + """Look in the 'ddl runner' for 'memos', then + note our name in that collection. + + This to ensure a particular named enum is operated + upon only once within any kind of create/drop + sequence without relying upon "checkfirst". + + """ + if not self.create_type: + return True + if "_ddl_runner" in kw: + ddl_runner = kw["_ddl_runner"] + if "_pg_enums" in ddl_runner.memo: + pg_enums = ddl_runner.memo["_pg_enums"] + else: + pg_enums = ddl_runner.memo["_pg_enums"] = set() + present = (self.schema, self.name) in pg_enums + pg_enums.add((self.schema, self.name)) + return present + else: + return False + + def _on_table_create(self, target, bind, checkfirst=False, **kw): + if ( + checkfirst + or ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + ) + ) and not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_table_drop(self, target, bind, checkfirst=False, **kw): + if ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + and not self._check_for_name_in_memos(checkfirst, kw) + ): + self.drop(bind=bind, checkfirst=checkfirst) + + def _on_metadata_create(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.drop(bind=bind, checkfirst=checkfirst) + + +class _ColonCast(elements.Cast): + __visit_name__ = "colon_cast" + + def __init__(self, expression, type_): + self.type = type_ + self.clause = expression + self.typeclause = elements.TypeClause(type_) + + +colspecs = { + sqltypes.ARRAY: _array.ARRAY, + sqltypes.Interval: INTERVAL, + sqltypes.Enum: ENUM, + sqltypes.JSON.JSONPathType: _json.JSONPathType, + sqltypes.JSON: _json.JSON, +} + +ischema_names = { + "_array": _array.ARRAY, + "hstore": _hstore.HSTORE, + "json": _json.JSON, + "jsonb": _json.JSONB, + "int4range": _ranges.INT4RANGE, + "int8range": _ranges.INT8RANGE, + "numrange": _ranges.NUMRANGE, + "daterange": _ranges.DATERANGE, + "tsrange": _ranges.TSRANGE, + "tstzrange": _ranges.TSTZRANGE, + "integer": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "character varying": VARCHAR, + "character": CHAR, + '"char"': sqltypes.String, + "name": sqltypes.String, + "text": TEXT, + "numeric": NUMERIC, + "float": FLOAT, + "real": REAL, + "inet": INET, + "cidr": CIDR, + "uuid": UUID, + "bit": BIT, + "bit varying": BIT, + "macaddr": MACADDR, + "money": MONEY, + "oid": OID, + "regclass": REGCLASS, + "double precision": DOUBLE_PRECISION, + "timestamp": TIMESTAMP, + "timestamp with time zone": TIMESTAMP, + "timestamp without time zone": TIMESTAMP, + "time with time zone": TIME, + "time without time zone": TIME, + "date": DATE, + "time": TIME, + "bytea": BYTEA, + "boolean": BOOLEAN, + "interval": INTERVAL, + "tsvector": TSVECTOR, +} + + +class PGCompiler(compiler.SQLCompiler): + def visit_colon_cast(self, element, **kw): + return "%s::%s" % ( + element.clause._compiler_dispatch(self, **kw), + element.typeclause._compiler_dispatch(self, **kw), + ) + + def visit_array(self, element, **kw): + return "ARRAY[%s]" % self.visit_clauselist(element, **kw) + + def visit_slice(self, element, **kw): + return "%s:%s" % ( + self.process(element.start, **kw), + self.process(element.stop, **kw), + ) + + def visit_json_getitem_op_binary( + self, binary, operator, _cast_applied=False, **kw + ): + if ( + not _cast_applied + and binary.type._type_affinity is not sqltypes.JSON + ): + kw["_cast_applied"] = True + return self.process(sql.cast(binary, binary.type), **kw) + + kw["eager_grouping"] = True + + return self._generate_generic_binary( + binary, " -> " if not _cast_applied else " ->> ", **kw + ) + + def visit_json_path_getitem_op_binary( + self, binary, operator, _cast_applied=False, **kw + ): + if ( + not _cast_applied + and binary.type._type_affinity is not sqltypes.JSON + ): + kw["_cast_applied"] = True + return self.process(sql.cast(binary, binary.type), **kw) + + kw["eager_grouping"] = True + return self._generate_generic_binary( + binary, " #> " if not _cast_applied else " #>> ", **kw + ) + + def visit_getitem_binary(self, binary, operator, **kw): + return "%s[%s]" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_aggregate_order_by(self, element, **kw): + return "%s ORDER BY %s" % ( + self.process(element.target, **kw), + self.process(element.order_by, **kw), + ) + + def visit_match_op_binary(self, binary, operator, **kw): + if "postgresql_regconfig" in binary.modifiers: + regconfig = self.render_literal_value( + binary.modifiers["postgresql_regconfig"], sqltypes.STRINGTYPE + ) + if regconfig: + return "%s @@ to_tsquery(%s, %s)" % ( + self.process(binary.left, **kw), + regconfig, + self.process(binary.right, **kw), + ) + return "%s @@ to_tsquery(%s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_ilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + + return "%s ILIKE %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) + + def visit_not_ilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return "%s NOT ILIKE %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) + + def _regexp_match(self, base_op, binary, operator, kw): + flags = binary.modifiers["flags"] + if flags is None: + return self._generate_generic_binary( + binary, " %s " % base_op, **kw + ) + if isinstance(flags, elements.BindParameter) and flags.value == "i": + return self._generate_generic_binary( + binary, " %s* " % base_op, **kw + ) + flags = self.process(flags, **kw) + string = self.process(binary.left, **kw) + pattern = self.process(binary.right, **kw) + return "%s %s CONCAT('(?', %s, ')', %s)" % ( + string, + base_op, + flags, + pattern, + ) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match("~", binary, operator, kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match("!~", binary, operator, kw) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + string = self.process(binary.left, **kw) + pattern = self.process(binary.right, **kw) + flags = binary.modifiers["flags"] + if flags is not None: + flags = self.process(flags, **kw) + replacement = self.process(binary.modifiers["replacement"], **kw) + if flags is None: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + string, + pattern, + replacement, + ) + else: + return "REGEXP_REPLACE(%s, %s, %s, %s)" % ( + string, + pattern, + replacement, + flags, + ) + + def visit_empty_set_expr(self, element_types): + # cast the empty set to the type we are comparing against. if + # we are comparing against the null type, pick an arbitrary + # datatype for the empty set + return "SELECT %s WHERE 1!=1" % ( + ", ".join( + "CAST(NULL AS %s)" + % self.dialect.type_compiler.process( + INTEGER() if type_._isnull else type_ + ) + for type_ in element_types or [INTEGER()] + ), + ) + + def render_literal_value(self, value, type_): + value = super(PGCompiler, self).render_literal_value(value, type_) + + if self.dialect._backslash_escapes: + value = value.replace("\\", "\\\\") + return value + + def visit_sequence(self, seq, **kw): + return "nextval('%s')" % self.preparer.format_sequence(seq) + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + text += " \n LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: + text += "\n LIMIT ALL" + text += " OFFSET " + self.process(select._offset_clause, **kw) + return text + + def format_from_hint_text(self, sqltext, table, hint, iscrud): + if hint.upper() != "ONLY": + raise exc.CompileError("Unrecognized hint: %r" % hint) + return "ONLY " + sqltext + + def get_select_precolumns(self, select, **kw): + # Do not call super().get_select_precolumns because + # it will warn/raise when distinct on is present + if select._distinct or select._distinct_on: + if select._distinct_on: + return ( + "DISTINCT ON (" + + ", ".join( + [ + self.process(col, **kw) + for col in select._distinct_on + ] + ) + + ") " + ) + else: + return "DISTINCT " + else: + return "" + + def for_update_clause(self, select, **kw): + + if select._for_update_arg.read: + if select._for_update_arg.key_share: + tmp = " FOR KEY SHARE" + else: + tmp = " FOR SHARE" + elif select._for_update_arg.key_share: + tmp = " FOR NO KEY UPDATE" + else: + tmp = " FOR UPDATE" + + if select._for_update_arg.of: + + tables = util.OrderedSet() + for c in select._for_update_arg.of: + tables.update(sql_util.surface_selectables_only(c)) + + tmp += " OF " + ", ".join( + self.process(table, ashint=True, use_schema=False, **kw) + for table in tables + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + if select._for_update_arg.skip_locked: + tmp += " SKIP LOCKED" + + return tmp + + def returning_clause(self, stmt, returning_cols): + + columns = [ + self._label_returning_column(stmt, c) + for c in expression._select_iterables(returning_cols) + ] + + return "RETURNING " + ", ".join(columns) + + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0], **kw) + start = self.process(func.clauses.clauses[1], **kw) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2], **kw) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) + else: + return "SUBSTRING(%s FROM %s)" % (s, start) + + def _on_conflict_target(self, clause, **kw): + + if clause.constraint_target is not None: + # target may be a name of an Index, UniqueConstraint or + # ExcludeConstraint. While there is a separate + # "max_identifier_length" for indexes, PostgreSQL uses the same + # length for all objects so we can use + # truncate_and_render_constraint_name + target_text = ( + "ON CONSTRAINT %s" + % self.preparer.truncate_and_render_constraint_name( + clause.constraint_target + ) + ) + elif clause.inferred_target_elements is not None: + target_text = "(%s)" % ", ".join( + ( + self.preparer.quote(c) + if isinstance(c, util.string_types) + else self.process(c, include_table=False, use_schema=False) + ) + for c in clause.inferred_target_elements + ) + if clause.inferred_target_whereclause is not None: + target_text += " WHERE %s" % self.process( + clause.inferred_target_whereclause, + include_table=False, + use_schema=False, + ) + else: + target_text = "" + + return target_text + + @util.memoized_property + def _is_safe_for_fast_insert_values_helper(self): + # don't allow fast executemany if _post_values_clause is + # present and is not an OnConflictDoNothing. what this means + # concretely is that the + # "fast insert executemany helper" won't be used, in other + # words we won't convert "executemany()" of many parameter + # sets into a single INSERT with many elements in VALUES. + # We can't apply that optimization safely if for example the + # statement includes a clause like "ON CONFLICT DO UPDATE" + + return self.insert_single_values_expr is not None and ( + self.statement._post_values_clause is None + or isinstance( + self.statement._post_values_clause, dml.OnConflictDoNothing + ) + ) + + def visit_on_conflict_do_nothing(self, on_conflict, **kw): + + target_text = self._on_conflict_target(on_conflict, **kw) + + if target_text: + return "ON CONFLICT %s DO NOTHING" % target_text + else: + return "ON CONFLICT DO NOTHING" + + def visit_on_conflict_do_update(self, on_conflict, **kw): + + clause = on_conflict + + target_text = self._on_conflict_target(on_conflict, **kw) + + action_set_ops = [] + + set_parameters = dict(clause.update_values_to_set) + # create a list of column assignment clauses as tuples + + insert_statement = self.stack[-1]["selectable"] + cols = insert_statement.table.c + for c in cols: + col_key = c.key + + if col_key in set_parameters: + value = set_parameters.pop(col_key) + elif c in set_parameters: + value = set_parameters.pop(c) + else: + continue + + if coercions._is_literal(value): + value = elements.BindParameter(None, value, type_=c.type) + + else: + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): + value = value._clone() + value.type = c.type + value_text = self.process(value.self_group(), use_schema=False) + + key_text = self.preparer.quote(c.name) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + # check for names that don't match columns + if set_parameters: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( + self.current_executable.table.name, + (", ".join("'%s'" % c for c in set_parameters)), + ) + ) + for k, v in set_parameters.items(): + key_text = ( + self.preparer.quote(k) + if isinstance(k, util.string_types) + else self.process(k, use_schema=False) + ) + value_text = self.process( + coercions.expect(roles.ExpressionElementRole, v), + use_schema=False, + ) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + action_text = ", ".join(action_set_ops) + if clause.update_whereclause is not None: + action_text += " WHERE %s" % self.process( + clause.update_whereclause, include_table=True, use_schema=False + ) + + return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + kw["asfrom"] = True + return "FROM " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the DELETE .. USING clause specific to PostgreSQL.""" + kw["asfrom"] = True + return "USING " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def fetch_clause(self, select, **kw): + # pg requires parens for non literal clauses. It's also required for + # bind parameters if a ::type casts is used by the driver (asyncpg), + # so it's easiest to just always add it + text = "" + if select._offset_clause is not None: + text += "\n OFFSET (%s) ROWS" % self.process( + select._offset_clause, **kw + ) + if select._fetch_clause is not None: + text += "\n FETCH FIRST (%s)%s ROWS %s" % ( + self.process(select._fetch_clause, **kw), + " PERCENT" if select._fetch_clause_options["percent"] else "", + "WITH TIES" + if select._fetch_clause_options["with_ties"] + else "ONLY", + ) + return text + + +class PGDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + + colspec = self.preparer.format_column(column) + impl_type = column.type.dialect_impl(self.dialect) + if isinstance(impl_type, sqltypes.TypeDecorator): + impl_type = impl_type.impl + + has_identity = ( + column.identity is not None + and self.dialect.supports_identity_columns + ) + + if ( + column.primary_key + and column is column.table._autoincrement_column + and ( + self.dialect.supports_smallserial + or not isinstance(impl_type, sqltypes.SmallInteger) + ) + and not has_identity + and ( + column.default is None + or ( + isinstance(column.default, schema.Sequence) + and column.default.optional + ) + ) + ): + if isinstance(impl_type, sqltypes.BigInteger): + colspec += " BIGSERIAL" + elif isinstance(impl_type, sqltypes.SmallInteger): + colspec += " SMALLSERIAL" + else: + colspec += " SERIAL" + else: + colspec += " " + self.dialect.type_compiler.process( + column.type, + type_expression=column, + identifier_preparer=self.preparer, + ) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if column.computed is not None: + colspec += " " + self.process(column.computed) + if has_identity: + colspec += " " + self.process(column.identity) + + if not column.nullable and not has_identity: + colspec += " NOT NULL" + elif column.nullable and has_identity: + colspec += " NULL" + return colspec + + def _define_constraint_validity(self, constraint): + not_valid = constraint.dialect_options["postgresql"]["not_valid"] + return " NOT VALID" if not_valid else "" + + def visit_check_constraint(self, constraint): + if constraint._type_bound: + typ = list(constraint.columns)[0].type + if ( + isinstance(typ, sqltypes.ARRAY) + and isinstance(typ.item_type, sqltypes.Enum) + and not typ.item_type.native_enum + ): + raise exc.CompileError( + "PostgreSQL dialect cannot produce the CHECK constraint " + "for ARRAY of non-native ENUM; please specify " + "create_constraint=False on this Enum datatype." + ) + + text = super(PGDDLCompiler, self).visit_check_constraint(constraint) + text += self._define_constraint_validity(constraint) + return text + + def visit_foreign_key_constraint(self, constraint): + text = super(PGDDLCompiler, self).visit_foreign_key_constraint( + constraint + ) + text += self._define_constraint_validity(constraint) + return text + + def visit_drop_table_comment(self, drop): + return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( + drop.element + ) + + def visit_create_enum_type(self, create): + type_ = create.element + + return "CREATE TYPE %s AS ENUM (%s)" % ( + self.preparer.format_type(type_), + ", ".join( + self.sql_compiler.process(sql.literal(e), literal_binds=True) + for e in type_.enums + ), + ) + + def visit_drop_enum_type(self, drop): + type_ = drop.element + + return "DROP TYPE %s" % (self.preparer.format_type(type_)) + + def visit_create_index(self, create): + preparer = self.preparer + index = create.element + self._verify_index_table(index) + text = "CREATE " + if index.unique: + text += "UNIQUE " + text += "INDEX " + + if self.dialect._supports_create_index_concurrently: + concurrently = index.dialect_options["postgresql"]["concurrently"] + if concurrently: + text += "CONCURRENTLY " + + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += "%s ON %s " % ( + self._prepared_index_name(index, include_schema=False), + preparer.format_table(index.table), + ) + + using = index.dialect_options["postgresql"]["using"] + if using: + text += ( + "USING %s " + % self.preparer.validate_sql_phrase(using, IDX_USING).lower() + ) + + ops = index.dialect_options["postgresql"]["ops"] + text += "(%s)" % ( + ", ".join( + [ + self.sql_compiler.process( + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr, + include_table=False, + literal_binds=True, + ) + + ( + (" " + ops[expr.key]) + if hasattr(expr, "key") and expr.key in ops + else "" + ) + for expr in index.expressions + ] + ) + ) + + includeclause = index.dialect_options["postgresql"]["include"] + if includeclause: + inclusions = [ + index.table.c[col] + if isinstance(col, util.string_types) + else col + for col in includeclause + ] + text += " INCLUDE (%s)" % ", ".join( + [preparer.quote(c.name) for c in inclusions] + ) + + withclause = index.dialect_options["postgresql"]["with"] + if withclause: + text += " WITH (%s)" % ( + ", ".join( + [ + "%s = %s" % storage_parameter + for storage_parameter in withclause.items() + ] + ) + ) + + tablespace_name = index.dialect_options["postgresql"]["tablespace"] + if tablespace_name: + text += " TABLESPACE %s" % preparer.quote(tablespace_name) + + whereclause = index.dialect_options["postgresql"]["where"] + if whereclause is not None: + whereclause = coercions.expect( + roles.DDLExpressionRole, whereclause + ) + + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, literal_binds=True + ) + text += " WHERE " + where_compiled + + return text + + def visit_drop_index(self, drop): + index = drop.element + + text = "\nDROP INDEX " + + if self.dialect._supports_drop_index_concurrently: + concurrently = index.dialect_options["postgresql"]["concurrently"] + if concurrently: + text += "CONCURRENTLY " + + if drop.if_exists: + text += "IF EXISTS " + + text += self._prepared_index_name(index, include_schema=True) + return text + + def visit_exclude_constraint(self, constraint, **kw): + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) + elements = [] + for expr, name, op in constraint._render_exprs: + kw["include_table"] = False + exclude_element = self.sql_compiler.process(expr, **kw) + ( + (" " + constraint.ops[expr.key]) + if hasattr(expr, "key") and expr.key in constraint.ops + else "" + ) + + elements.append("%s WITH %s" % (exclude_element, op)) + text += "EXCLUDE USING %s (%s)" % ( + self.preparer.validate_sql_phrase( + constraint.using, IDX_USING + ).lower(), + ", ".join(elements), + ) + if constraint.where is not None: + text += " WHERE (%s)" % self.sql_compiler.process( + constraint.where, literal_binds=True + ) + text += self.define_constraint_deferrability(constraint) + return text + + def post_create_table(self, table): + table_opts = [] + pg_opts = table.dialect_options["postgresql"] + + inherits = pg_opts.get("inherits") + if inherits is not None: + if not isinstance(inherits, (list, tuple)): + inherits = (inherits,) + table_opts.append( + "\n INHERITS ( " + + ", ".join(self.preparer.quote(name) for name in inherits) + + " )" + ) + + if pg_opts["partition_by"]: + table_opts.append("\n PARTITION BY %s" % pg_opts["partition_by"]) + + if pg_opts["with_oids"] is True: + table_opts.append("\n WITH OIDS") + elif pg_opts["with_oids"] is False: + table_opts.append("\n WITHOUT OIDS") + + if pg_opts["on_commit"]: + on_commit_options = pg_opts["on_commit"].replace("_", " ").upper() + table_opts.append("\n ON COMMIT %s" % on_commit_options) + + if pg_opts["tablespace"]: + tablespace_name = pg_opts["tablespace"] + table_opts.append( + "\n TABLESPACE %s" % self.preparer.quote(tablespace_name) + ) + + return "".join(table_opts) + + def visit_computed_column(self, generated): + if generated.persisted is False: + raise exc.CompileError( + "PostrgreSQL computed columns do not support 'virtual' " + "persistence; set the 'persisted' flag to None or True for " + "PostgreSQL support." + ) + + return "GENERATED ALWAYS AS (%s) STORED" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + + def visit_create_sequence(self, create, **kw): + prefix = None + if create.element.data_type is not None: + prefix = " AS %s" % self.type_compiler.process( + create.element.data_type + ) + + return super(PGDDLCompiler, self).visit_create_sequence( + create, prefix=prefix, **kw + ) + + +class PGTypeCompiler(compiler.GenericTypeCompiler): + def visit_TSVECTOR(self, type_, **kw): + return "TSVECTOR" + + def visit_INET(self, type_, **kw): + return "INET" + + def visit_CIDR(self, type_, **kw): + return "CIDR" + + def visit_MACADDR(self, type_, **kw): + return "MACADDR" + + def visit_MONEY(self, type_, **kw): + return "MONEY" + + def visit_OID(self, type_, **kw): + return "OID" + + def visit_REGCLASS(self, type_, **kw): + return "REGCLASS" + + def visit_FLOAT(self, type_, **kw): + if not type_.precision: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {"precision": type_.precision} + + def visit_DOUBLE_PRECISION(self, type_, **kw): + return "DOUBLE PRECISION" + + def visit_BIGINT(self, type_, **kw): + return "BIGINT" + + def visit_HSTORE(self, type_, **kw): + return "HSTORE" + + def visit_JSON(self, type_, **kw): + return "JSON" + + def visit_JSONB(self, type_, **kw): + return "JSONB" + + def visit_INT4RANGE(self, type_, **kw): + return "INT4RANGE" + + def visit_INT8RANGE(self, type_, **kw): + return "INT8RANGE" + + def visit_NUMRANGE(self, type_, **kw): + return "NUMRANGE" + + def visit_DATERANGE(self, type_, **kw): + return "DATERANGE" + + def visit_TSRANGE(self, type_, **kw): + return "TSRANGE" + + def visit_TSTZRANGE(self, type_, **kw): + return "TSTZRANGE" + + def visit_datetime(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) + + def visit_enum(self, type_, **kw): + if not type_.native_enum or not self.dialect.supports_native_enum: + return super(PGTypeCompiler, self).visit_enum(type_, **kw) + else: + return self.visit_ENUM(type_, **kw) + + def visit_ENUM(self, type_, identifier_preparer=None, **kw): + if identifier_preparer is None: + identifier_preparer = self.dialect.identifier_preparer + return identifier_preparer.format_type(type_) + + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP%s %s" % ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", + ) + + def visit_TIME(self, type_, **kw): + return "TIME%s %s" % ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", + ) + + def visit_INTERVAL(self, type_, **kw): + text = "INTERVAL" + if type_.fields is not None: + text += " " + type_.fields + if type_.precision is not None: + text += " (%d)" % type_.precision + return text + + def visit_BIT(self, type_, **kw): + if type_.varying: + compiled = "BIT VARYING" + if type_.length is not None: + compiled += "(%d)" % type_.length + else: + compiled = "BIT(%d)" % type_.length + return compiled + + def visit_UUID(self, type_, **kw): + return "UUID" + + def visit_large_binary(self, type_, **kw): + return self.visit_BYTEA(type_, **kw) + + def visit_BYTEA(self, type_, **kw): + return "BYTEA" + + def visit_ARRAY(self, type_, **kw): + + inner = self.process(type_.item_type, **kw) + return re.sub( + r"((?: COLLATE.*)?)$", + ( + r"%s\1" + % ( + "[]" + * (type_.dimensions if type_.dimensions is not None else 1) + ) + ), + inner, + count=1, + ) + + +class PGIdentifierPreparer(compiler.IdentifierPreparer): + + reserved_words = RESERVED_WORDS + + def _unquote_identifier(self, value): + if value[0] == self.initial_quote: + value = value[1:-1].replace( + self.escape_to_quote, self.escape_quote + ) + return value + + def format_type(self, type_, use_schema=True): + if not type_.name: + raise exc.CompileError("PostgreSQL ENUM type requires a name.") + + name = self.quote(type_.name) + effective_schema = self.schema_for_object(type_) + + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): + name = self.quote_schema(effective_schema) + "." + name + return name + + +class PGInspector(reflection.Inspector): + def get_table_oid(self, table_name, schema=None): + """Return the OID for the given table name.""" + + with self._operation_context() as conn: + return self.dialect.get_table_oid( + conn, table_name, schema, info_cache=self.info_cache + ) + + def get_enums(self, schema=None): + """Return a list of ENUM objects. + + Each member is a dictionary containing these fields: + + * name - name of the enum + * schema - the schema name for the enum. + * visible - boolean, whether or not this enum is visible + in the default search path. + * labels - a list of string labels that apply to the enum. + + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to '*' to + indicate load enums for all schemas. + + .. versionadded:: 1.0.0 + + """ + schema = schema or self.default_schema_name + with self._operation_context() as conn: + return self.dialect._load_enums(conn, schema) + + def get_foreign_table_names(self, schema=None): + """Return a list of FOREIGN TABLE names. + + Behavior is similar to that of + :meth:`_reflection.Inspector.get_table_names`, + except that the list is limited to those tables that report a + ``relkind`` value of ``f``. + + .. versionadded:: 1.0.0 + + """ + schema = schema or self.default_schema_name + with self._operation_context() as conn: + return self.dialect._get_foreign_table_names(conn, schema) + + def get_view_names(self, schema=None, include=("plain", "materialized")): + """Return all view names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + + :param include: specify which types of views to return. Passed + as a string value (for a single type) or a tuple (for any number + of types). Defaults to ``('plain', 'materialized')``. + + .. versionadded:: 1.1 + + """ + + with self._operation_context() as conn: + return self.dialect.get_view_names( + conn, schema, info_cache=self.info_cache, include=include + ) + + +class CreateEnumType(schema._CreateDropBase): + __visit_name__ = "create_enum_type" + + +class DropEnumType(schema._CreateDropBase): + __visit_name__ = "drop_enum_type" + + +class PGExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq, type_): + return self._execute_scalar( + ( + "select nextval('%s')" + % self.identifier_preparer.format_sequence(seq) + ), + type_, + ) + + def get_insert_default(self, column): + if column.primary_key and column is column.table._autoincrement_column: + if column.server_default and column.server_default.has_argument: + + # pre-execute passive defaults on primary key columns + return self._execute_scalar( + "select %s" % column.server_default.arg, column.type + ) + + elif column.default is None or ( + column.default.is_sequence and column.default.optional + ): + # execute the sequence associated with a SERIAL primary + # key column. for non-primary-key SERIAL, the ID just + # generates server side. + + try: + seq_name = column._postgresql_seq_name + except AttributeError: + tab = column.table.name + col = column.name + tab = tab[0 : 29 + max(0, (29 - len(col)))] + col = col[0 : 29 + max(0, (29 - len(tab)))] + name = "%s_%s_seq" % (tab, col) + column._postgresql_seq_name = seq_name = name + + if column.table is not None: + effective_schema = self.connection.schema_for_object( + column.table + ) + else: + effective_schema = None + + if effective_schema is not None: + exc = 'select nextval(\'"%s"."%s"\')' % ( + effective_schema, + seq_name, + ) + else: + exc = "select nextval('\"%s\"')" % (seq_name,) + + return self._execute_scalar(exc, column.type) + + return super(PGExecutionContext, self).get_insert_default(column) + + def should_autocommit_text(self, statement): + return AUTOCOMMIT_REGEXP.match(statement) + + +class PGReadOnlyConnectionCharacteristic( + characteristics.ConnectionCharacteristic +): + transactional = True + + def reset_characteristic(self, dialect, dbapi_conn): + dialect.set_readonly(dbapi_conn, False) + + def set_characteristic(self, dialect, dbapi_conn, value): + dialect.set_readonly(dbapi_conn, value) + + def get_characteristic(self, dialect, dbapi_conn): + return dialect.get_readonly(dbapi_conn) + + +class PGDeferrableConnectionCharacteristic( + characteristics.ConnectionCharacteristic +): + transactional = True + + def reset_characteristic(self, dialect, dbapi_conn): + dialect.set_deferrable(dbapi_conn, False) + + def set_characteristic(self, dialect, dbapi_conn, value): + dialect.set_deferrable(dbapi_conn, value) + + def get_characteristic(self, dialect, dbapi_conn): + return dialect.get_deferrable(dbapi_conn) + + +class PGDialect(default.DefaultDialect): + name = "postgresql" + supports_statement_cache = True + supports_alter = True + max_identifier_length = 63 + supports_sane_rowcount = True + + supports_native_enum = True + supports_native_boolean = True + supports_smallserial = True + + supports_sequences = True + sequences_optional = True + preexecute_autoincrement_sequences = True + postfetch_lastrowid = False + + supports_comments = True + supports_default_values = True + + supports_default_metavalue = True + + supports_empty_insert = False + supports_multivalues_insert = True + supports_identity_columns = True + + default_paramstyle = "pyformat" + ischema_names = ischema_names + colspecs = colspecs + + statement_compiler = PGCompiler + ddl_compiler = PGDDLCompiler + type_compiler = PGTypeCompiler + preparer = PGIdentifierPreparer + execution_ctx_cls = PGExecutionContext + inspector = PGInspector + isolation_level = None + + implicit_returning = True + full_returning = True + + connection_characteristics = ( + default.DefaultDialect.connection_characteristics + ) + connection_characteristics = connection_characteristics.union( + { + "postgresql_readonly": PGReadOnlyConnectionCharacteristic(), + "postgresql_deferrable": PGDeferrableConnectionCharacteristic(), + } + ) + + construct_arguments = [ + ( + schema.Index, + { + "using": False, + "include": None, + "where": None, + "ops": {}, + "concurrently": False, + "with": {}, + "tablespace": None, + }, + ), + ( + schema.Table, + { + "ignore_search_path": False, + "tablespace": None, + "partition_by": None, + "with_oids": None, + "on_commit": None, + "inherits": None, + }, + ), + ( + schema.CheckConstraint, + { + "not_valid": False, + }, + ), + ( + schema.ForeignKeyConstraint, + { + "not_valid": False, + }, + ), + ] + + reflection_options = ("postgresql_ignore_search_path",) + + _backslash_escapes = True + _supports_create_index_concurrently = True + _supports_drop_index_concurrently = True + + def __init__( + self, + isolation_level=None, + json_serializer=None, + json_deserializer=None, + **kwargs + ): + default.DefaultDialect.__init__(self, **kwargs) + + # the isolation_level parameter to the PGDialect itself is legacy. + # still works however the execution_options method is the one that + # is documented. + self.isolation_level = isolation_level + self._json_deserializer = json_deserializer + self._json_serializer = json_serializer + + def initialize(self, connection): + super(PGDialect, self).initialize(connection) + + if self.server_version_info <= (8, 2): + self.full_returning = self.implicit_returning = False + + self.supports_native_enum = self.server_version_info >= (8, 3) + if not self.supports_native_enum: + self.colspecs = self.colspecs.copy() + # pop base Enum type + self.colspecs.pop(sqltypes.Enum, None) + # psycopg2, others may have placed ENUM here as well + self.colspecs.pop(ENUM, None) + + # https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 + self.supports_smallserial = self.server_version_info >= (9, 2) + + if self.server_version_info < (8, 2): + self._backslash_escapes = False + else: + # ensure this query is not emitted on server version < 8.2 + # as it will fail + std_string = connection.exec_driver_sql( + "show standard_conforming_strings" + ).scalar() + self._backslash_escapes = std_string == "off" + + self._supports_create_index_concurrently = ( + self.server_version_info >= (8, 2) + ) + self._supports_drop_index_concurrently = self.server_version_info >= ( + 9, + 2, + ) + self.supports_identity_columns = self.server_version_info >= (10,) + + def on_connect(self): + if self.isolation_level is not None: + + def connect(conn): + self.set_isolation_level(conn, self.isolation_level) + + return connect + else: + return None + + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ] + ) + + def set_isolation_level(self, connection, level): + level = level.replace("_", " ") + if level not in self._isolation_lookup: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) + ) + cursor = connection.cursor() + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION " + "ISOLATION LEVEL %s" % level + ) + cursor.execute("COMMIT") + cursor.close() + + def get_isolation_level(self, connection): + cursor = connection.cursor() + cursor.execute("show transaction isolation level") + val = cursor.fetchone()[0] + cursor.close() + return val.upper() + + def set_readonly(self, connection, value): + raise NotImplementedError() + + def get_readonly(self, connection): + raise NotImplementedError() + + def set_deferrable(self, connection, value): + raise NotImplementedError() + + def get_deferrable(self, connection): + raise NotImplementedError() + + def do_begin_twophase(self, connection, xid): + self.do_begin(connection.connection) + + def do_prepare_twophase(self, connection, xid): + connection.exec_driver_sql("PREPARE TRANSACTION '%s'" % xid) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + if recover: + # FIXME: ugly hack to get out of transaction + # context when committing recoverable transactions + # Must find out a way how to make the dbapi not + # open a transaction. + connection.exec_driver_sql("ROLLBACK") + connection.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid) + connection.exec_driver_sql("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + if recover: + connection.exec_driver_sql("ROLLBACK") + connection.exec_driver_sql("COMMIT PREPARED '%s'" % xid) + connection.exec_driver_sql("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + resultset = connection.execute( + sql.text("SELECT gid FROM pg_prepared_xacts") + ) + return [row[0] for row in resultset] + + def _get_default_schema_name(self, connection): + return connection.exec_driver_sql("select current_schema()").scalar() + + def has_schema(self, connection, schema): + query = ( + "select nspname from pg_namespace " "where lower(nspname)=:schema" + ) + cursor = connection.execute( + sql.text(query).bindparams( + sql.bindparam( + "schema", + util.text_type(schema.lower()), + type_=sqltypes.Unicode, + ) + ) + ) + + return bool(cursor.first()) + + def has_table(self, connection, table_name, schema=None): + self._ensure_has_table_connection(connection) + # seems like case gets folded in pg_class... + if schema is None: + cursor = connection.execute( + sql.text( + "select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where " + "pg_catalog.pg_table_is_visible(c.oid) " + "and relname=:name" + ).bindparams( + sql.bindparam( + "name", + util.text_type(table_name), + type_=sqltypes.Unicode, + ) + ) + ) + else: + cursor = connection.execute( + sql.text( + "select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=:schema and " + "relname=:name" + ).bindparams( + sql.bindparam( + "name", + util.text_type(table_name), + type_=sqltypes.Unicode, + ), + sql.bindparam( + "schema", + util.text_type(schema), + type_=sqltypes.Unicode, + ), + ) + ) + return bool(cursor.first()) + + def has_sequence(self, connection, sequence_name, schema=None): + if schema is None: + schema = self.default_schema_name + cursor = connection.execute( + sql.text( + "SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and " + "n.nspname=:schema and relname=:name" + ).bindparams( + sql.bindparam( + "name", + util.text_type(sequence_name), + type_=sqltypes.Unicode, + ), + sql.bindparam( + "schema", + util.text_type(schema), + type_=sqltypes.Unicode, + ), + ) + ) + + return bool(cursor.first()) + + def has_type(self, connection, type_name, schema=None): + if schema is not None: + query = """ + SELECT EXISTS ( + SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n + WHERE t.typnamespace = n.oid + AND t.typname = :typname + AND n.nspname = :nspname + ) + """ + query = sql.text(query) + else: + query = """ + SELECT EXISTS ( + SELECT * FROM pg_catalog.pg_type t + WHERE t.typname = :typname + AND pg_type_is_visible(t.oid) + ) + """ + query = sql.text(query) + query = query.bindparams( + sql.bindparam( + "typname", util.text_type(type_name), type_=sqltypes.Unicode + ) + ) + if schema is not None: + query = query.bindparams( + sql.bindparam( + "nspname", util.text_type(schema), type_=sqltypes.Unicode + ) + ) + cursor = connection.execute(query) + return bool(cursor.scalar()) + + def _get_server_version_info(self, connection): + v = connection.exec_driver_sql("select pg_catalog.version()").scalar() + m = re.match( + r".*(?:PostgreSQL|EnterpriseDB) " + r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?", + v, + ) + if not m: + raise AssertionError( + "Could not determine version from string '%s'" % v + ) + return tuple([int(x) for x in m.group(1, 2, 3) if x is not None]) + + @reflection.cache + def get_table_oid(self, connection, table_name, schema=None, **kw): + """Fetch the oid for schema.table_name. + + Several reflection methods require the table oid. The idea for using + this method is that it can be fetched one time and cached for + subsequent calls. + + """ + table_oid = None + if schema is not None: + schema_where_clause = "n.nspname = :schema" + else: + schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" + query = ( + """ + SELECT c.oid + FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE (%s) + AND c.relname = :table_name AND c.relkind in + ('r', 'v', 'm', 'f', 'p') + """ + % schema_where_clause + ) + # Since we're binding to unicode, table_name and schema_name must be + # unicode. + table_name = util.text_type(table_name) + if schema is not None: + schema = util.text_type(schema) + s = sql.text(query).bindparams(table_name=sqltypes.Unicode) + s = s.columns(oid=sqltypes.Integer) + if schema: + s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode)) + c = connection.execute(s, dict(table_name=table_name, schema=schema)) + table_oid = c.scalar() + if table_oid is None: + raise exc.NoSuchTableError(table_name) + return table_oid + + @reflection.cache + def get_schema_names(self, connection, **kw): + result = connection.execute( + sql.text( + "SELECT nspname FROM pg_namespace " + "WHERE nspname NOT LIKE 'pg_%' " + "ORDER BY nspname" + ).columns(nspname=sqltypes.Unicode) + ) + return [name for name, in result] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + result = connection.execute( + sql.text( + "SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')" + ).columns(relname=sqltypes.Unicode), + dict( + schema=schema + if schema is not None + else self.default_schema_name + ), + ) + return [name for name, in result] + + @reflection.cache + def _get_foreign_table_names(self, connection, schema=None, **kw): + result = connection.execute( + sql.text( + "SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind = 'f'" + ).columns(relname=sqltypes.Unicode), + dict( + schema=schema + if schema is not None + else self.default_schema_name + ), + ) + return [name for name, in result] + + @reflection.cache + def get_view_names( + self, connection, schema=None, include=("plain", "materialized"), **kw + ): + + include_kind = {"plain": "v", "materialized": "m"} + try: + kinds = [include_kind[i] for i in util.to_list(include)] + except KeyError: + raise ValueError( + "include %r unknown, needs to be a sequence containing " + "one or both of 'plain' and 'materialized'" % (include,) + ) + if not kinds: + raise ValueError( + "empty include, needs to be a sequence containing " + "one or both of 'plain' and 'materialized'" + ) + + result = connection.execute( + sql.text( + "SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind IN (%s)" + % (", ".join("'%s'" % elem for elem in kinds)) + ).columns(relname=sqltypes.Unicode), + dict( + schema=schema + if schema is not None + else self.default_schema_name + ), + ) + return [name for name, in result] + + @reflection.cache + def get_sequence_names(self, connection, schema=None, **kw): + if not schema: + schema = self.default_schema_name + cursor = connection.execute( + sql.text( + "SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and " + "n.nspname=:schema" + ).bindparams( + sql.bindparam( + "schema", + util.text_type(schema), + type_=sqltypes.Unicode, + ), + ) + ) + return [row[0] for row in cursor] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + view_def = connection.scalar( + sql.text( + "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relname = :view_name " + "AND c.relkind IN ('v', 'm')" + ).columns(view_def=sqltypes.Unicode), + dict( + schema=schema + if schema is not None + else self.default_schema_name, + view_name=view_name, + ), + ) + return view_def + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) + + generated = ( + "a.attgenerated as generated" + if self.server_version_info >= (12,) + else "NULL as generated" + ) + if self.server_version_info >= (10,): + # a.attidentity != '' is required or it will reflect also + # serial columns as identity. + identity = """\ + (SELECT json_build_object( + 'always', a.attidentity = 'a', + 'start', s.seqstart, + 'increment', s.seqincrement, + 'minvalue', s.seqmin, + 'maxvalue', s.seqmax, + 'cache', s.seqcache, + 'cycle', s.seqcycle) + FROM pg_catalog.pg_sequence s + JOIN pg_catalog.pg_class c on s.seqrelid = c."oid" + WHERE c.relkind = 'S' + AND a.attidentity != '' + AND s.seqrelid = pg_catalog.pg_get_serial_sequence( + a.attrelid::regclass::text, a.attname + )::regclass::oid + ) as identity_options\ + """ + else: + identity = "NULL as identity_options" + + SQL_COLS = """ + SELECT a.attname, + pg_catalog.format_type(a.atttypid, a.atttypmod), + ( + SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid) + FROM pg_catalog.pg_attrdef d + WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum + AND a.atthasdef + ) AS DEFAULT, + a.attnotnull, + a.attrelid as table_oid, + pgd.description as comment, + %s, + %s + FROM pg_catalog.pg_attribute a + LEFT JOIN pg_catalog.pg_description pgd ON ( + pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum) + WHERE a.attrelid = :table_oid + AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum + """ % ( + generated, + identity, + ) + s = ( + sql.text(SQL_COLS) + .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer)) + .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode) + ) + c = connection.execute(s, dict(table_oid=table_oid)) + rows = c.fetchall() + + # dictionary with (name, ) if default search path or (schema, name) + # as keys + domains = self._load_domains(connection) + + # dictionary with (name, ) if default search path or (schema, name) + # as keys + enums = dict( + ((rec["name"],), rec) + if rec["visible"] + else ((rec["schema"], rec["name"]), rec) + for rec in self._load_enums(connection, schema="*") + ) + + # format columns + columns = [] + + for ( + name, + format_type, + default_, + notnull, + table_oid, + comment, + generated, + identity, + ) in rows: + column_info = self._get_column_info( + name, + format_type, + default_, + notnull, + domains, + enums, + schema, + comment, + generated, + identity, + ) + columns.append(column_info) + return columns + + def _get_column_info( + self, + name, + format_type, + default, + notnull, + domains, + enums, + schema, + comment, + generated, + identity, + ): + def _handle_array_type(attype): + return ( + # strip '[]' from integer[], etc. + re.sub(r"\[\]$", "", attype), + attype.endswith("[]"), + ) + + # strip (*) from character varying(5), timestamp(5) + # with time zone, geometry(POLYGON), etc. + attype = re.sub(r"\(.*\)", "", format_type) + + # strip '[]' from integer[], etc. and check if an array + attype, is_array = _handle_array_type(attype) + + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + + nullable = not notnull + + charlen = re.search(r"\(([\d,]+)\)", format_type) + if charlen: + charlen = charlen.group(1) + args = re.search(r"\((.*)\)", format_type) + if args and args.group(1): + args = tuple(re.split(r"\s*,\s*", args.group(1))) + else: + args = () + kwargs = {} + + if attype == "numeric": + if charlen: + prec, scale = charlen.split(",") + args = (int(prec), int(scale)) + else: + args = () + elif attype == "double precision": + args = (53,) + elif attype == "integer": + args = () + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype == "bit varying": + kwargs["varying"] = True + if charlen: + args = (int(charlen),) + else: + args = () + elif attype.startswith("interval"): + field_match = re.match(r"interval (.+)", attype, re.I) + if charlen: + kwargs["precision"] = int(charlen) + if field_match: + kwargs["fields"] = field_match.group(1) + attype = "interval" + args = () + elif charlen: + args = (int(charlen),) + + while True: + # looping here to suit nested domains + if attype in self.ischema_names: + coltype = self.ischema_names[attype] + break + elif enum_or_domain_key in enums: + enum = enums[enum_or_domain_key] + coltype = ENUM + kwargs["name"] = enum["name"] + if not enum["visible"]: + kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) + break + elif enum_or_domain_key in domains: + domain = domains[enum_or_domain_key] + attype = domain["attype"] + attype, is_array = _handle_array_type(attype) + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + # A table can't override a not null on the domain, + # but can override nullable + nullable = nullable and domain["nullable"] + if domain["default"] and not default: + # It can, however, override the default + # value, but can't set it to null. + default = domain["default"] + continue + else: + coltype = None + break + + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = self.ischema_names["_array"](coltype) + else: + util.warn( + "Did not recognize type '%s' of column '%s'" % (attype, name) + ) + coltype = sqltypes.NULLTYPE + + # If a zero byte or blank string depending on driver (is also absent + # for older PG versions), then not a generated column. Otherwise, s = + # stored. (Other values might be added in the future.) + if generated not in (None, "", b"\x00"): + computed = dict( + sqltext=default, persisted=generated in ("s", b"s") + ) + default = None + else: + computed = None + + # adjust the default value + autoincrement = False + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + if issubclass(coltype._type_affinity, sqltypes.Integer): + autoincrement = True + # the default is related to a Sequence + sch = schema + if "." not in match.group(2) and sch is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / + # "quote schema" + default = ( + match.group(1) + + ('"%s"' % sch) + + "." + + match.group(2) + + match.group(3) + ) + + column_info = dict( + name=name, + type=coltype, + nullable=nullable, + default=default, + autoincrement=autoincrement or identity is not None, + comment=comment, + ) + if computed is not None: + column_info["computed"] = computed + if identity is not None: + column_info["identity"] = identity + return column_info + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) + + if self.server_version_info < (8, 4): + PK_SQL = """ + SELECT a.attname + FROM + pg_class t + join pg_index ix on t.oid = ix.indrelid + join pg_attribute a + on t.oid=a.attrelid AND %s + WHERE + t.oid = :table_oid and ix.indisprimary = 't' + ORDER BY a.attnum + """ % self._pg_index_any( + "a.attnum", "ix.indkey" + ) + + else: + # unnest() and generate_subscripts() both introduced in + # version 8.4 + PK_SQL = """ + SELECT a.attname + FROM pg_attribute a JOIN ( + SELECT unnest(ix.indkey) attnum, + generate_subscripts(ix.indkey, 1) ord + FROM pg_index ix + WHERE ix.indrelid = :table_oid AND ix.indisprimary + ) k ON a.attnum=k.attnum + WHERE a.attrelid = :table_oid + ORDER BY k.ord + """ + t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode) + c = connection.execute(t, dict(table_oid=table_oid)) + cols = [r[0] for r in c.fetchall()] + + PK_CONS_SQL = """ + SELECT conname + FROM pg_catalog.pg_constraint r + WHERE r.conrelid = :table_oid AND r.contype = 'p' + ORDER BY 1 + """ + t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode) + c = connection.execute(t, dict(table_oid=table_oid)) + name = c.scalar() + + return {"constrained_columns": cols, "name": name} + + @reflection.cache + def get_foreign_keys( + self, + connection, + table_name, + schema=None, + postgresql_ignore_search_path=False, + **kw + ): + preparer = self.identifier_preparer + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) + + FK_SQL = """ + SELECT r.conname, + pg_catalog.pg_get_constraintdef(r.oid, true) as condef, + n.nspname as conschema + FROM pg_catalog.pg_constraint r, + pg_namespace n, + pg_class c + + WHERE r.conrelid = :table AND + r.contype = 'f' AND + c.oid = confrelid AND + n.oid = c.relnamespace + ORDER BY 1 + """ + # https://www.postgresql.org/docs/9.0/static/sql-createtable.html + FK_REGEX = re.compile( + r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)" + r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?" + r"[\s]?(ON UPDATE " + r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" + r"[\s]?(ON DELETE " + r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" + r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?" + r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?" + ) + + t = sql.text(FK_SQL).columns( + conname=sqltypes.Unicode, condef=sqltypes.Unicode + ) + c = connection.execute(t, dict(table=table_oid)) + fkeys = [] + for conname, condef, conschema in c.fetchall(): + m = re.search(FK_REGEX, condef).groups() + + ( + constrained_columns, + referred_schema, + referred_table, + referred_columns, + _, + match, + _, + onupdate, + _, + ondelete, + deferrable, + _, + initially, + ) = m + + if deferrable is not None: + deferrable = True if deferrable == "DEFERRABLE" else False + constrained_columns = [ + preparer._unquote_identifier(x) + for x in re.split(r"\s*,\s*", constrained_columns) + ] + + if postgresql_ignore_search_path: + # when ignoring search path, we use the actual schema + # provided it isn't the "default" schema + if conschema != self.default_schema_name: + referred_schema = conschema + else: + referred_schema = schema + elif referred_schema: + # referred_schema is the schema that we regexp'ed from + # pg_get_constraintdef(). If the schema is in the search + # path, pg_get_constraintdef() will give us None. + referred_schema = preparer._unquote_identifier(referred_schema) + elif schema is not None and schema == conschema: + # If the actual schema matches the schema of the table + # we're reflecting, then we will use that. + referred_schema = schema + + referred_table = preparer._unquote_identifier(referred_table) + referred_columns = [ + preparer._unquote_identifier(x) + for x in re.split(r"\s*,\s", referred_columns) + ] + options = { + k: v + for k, v in [ + ("onupdate", onupdate), + ("ondelete", ondelete), + ("initially", initially), + ("deferrable", deferrable), + ("match", match), + ] + if v is not None and v != "NO ACTION" + } + fkey_d = { + "name": conname, + "constrained_columns": constrained_columns, + "referred_schema": referred_schema, + "referred_table": referred_table, + "referred_columns": referred_columns, + "options": options, + } + fkeys.append(fkey_d) + return fkeys + + def _pg_index_any(self, col, compare_to): + if self.server_version_info < (8, 1): + # https://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us + # "In CVS tip you could replace this with "attnum = ANY (indkey)". + # Unfortunately, most array support doesn't work on int2vector in + # pre-8.1 releases, so I think you're kinda stuck with the above + # for now. + # regards, tom lane" + return "(%s)" % " OR ".join( + "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10) + ) + else: + return "%s = ANY(%s)" % (col, compare_to) + + @reflection.cache + def get_indexes(self, connection, table_name, schema, **kw): + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) + + # cast indkey as varchar since it's an int2vector, + # returned as a list by some drivers such as pypostgresql + + if self.server_version_info < (8, 5): + IDX_SQL = """ + SELECT + i.relname as relname, + ix.indisunique, ix.indexprs, ix.indpred, + a.attname, a.attnum, NULL, ix.indkey%s, + %s, %s, am.amname, + NULL as indnkeyatts + FROM + pg_class t + join pg_index ix on t.oid = ix.indrelid + join pg_class i on i.oid = ix.indexrelid + left outer join + pg_attribute a + on t.oid = a.attrelid and %s + left outer join + pg_am am + on i.relam = am.oid + WHERE + t.relkind IN ('r', 'v', 'f', 'm') + and t.oid = :table_oid + and ix.indisprimary = 'f' + ORDER BY + t.relname, + i.relname + """ % ( + # version 8.3 here was based on observing the + # cast does not work in PG 8.2.4, does work in 8.3.0. + # nothing in PG changelogs regarding this. + "::varchar" if self.server_version_info >= (8, 3) else "", + "ix.indoption::varchar" + if self.server_version_info >= (8, 3) + else "NULL", + "i.reloptions" + if self.server_version_info >= (8, 2) + else "NULL", + self._pg_index_any("a.attnum", "ix.indkey"), + ) + else: + IDX_SQL = """ + SELECT + i.relname as relname, + ix.indisunique, ix.indexprs, + a.attname, a.attnum, c.conrelid, ix.indkey::varchar, + ix.indoption::varchar, i.reloptions, am.amname, + pg_get_expr(ix.indpred, ix.indrelid), + %s as indnkeyatts + FROM + pg_class t + join pg_index ix on t.oid = ix.indrelid + join pg_class i on i.oid = ix.indexrelid + left outer join + pg_attribute a + on t.oid = a.attrelid and a.attnum = ANY(ix.indkey) + left outer join + pg_constraint c + on (ix.indrelid = c.conrelid and + ix.indexrelid = c.conindid and + c.contype in ('p', 'u', 'x')) + left outer join + pg_am am + on i.relam = am.oid + WHERE + t.relkind IN ('r', 'v', 'f', 'm', 'p') + and t.oid = :table_oid + and ix.indisprimary = 'f' + ORDER BY + t.relname, + i.relname + """ % ( + "ix.indnkeyatts" + if self.server_version_info >= (11, 0) + else "NULL", + ) + + t = sql.text(IDX_SQL).columns( + relname=sqltypes.Unicode, attname=sqltypes.Unicode + ) + c = connection.execute(t, dict(table_oid=table_oid)) + + indexes = defaultdict(lambda: defaultdict(dict)) + + sv_idx_name = None + for row in c.fetchall(): + ( + idx_name, + unique, + expr, + col, + col_num, + conrelid, + idx_key, + idx_option, + options, + amname, + filter_definition, + indnkeyatts, + ) = row + + if expr: + if idx_name != sv_idx_name: + util.warn( + "Skipped unsupported reflection of " + "expression-based index %s" % idx_name + ) + sv_idx_name = idx_name + continue + + has_idx = idx_name in indexes + index = indexes[idx_name] + if col is not None: + index["cols"][col_num] = col + if not has_idx: + idx_keys = idx_key.split() + # "The number of key columns in the index, not counting any + # included columns, which are merely stored and do not + # participate in the index semantics" + if indnkeyatts and idx_keys[indnkeyatts:]: + # this is a "covering index" which has INCLUDE columns + # as well as regular index columns + inc_keys = idx_keys[indnkeyatts:] + idx_keys = idx_keys[:indnkeyatts] + else: + inc_keys = [] + + index["key"] = [int(k.strip()) for k in idx_keys] + index["inc"] = [int(k.strip()) for k in inc_keys] + + # (new in pg 8.3) + # "pg_index.indoption" is list of ints, one per column/expr. + # int acts as bitmask: 0x01=DESC, 0x02=NULLSFIRST + sorting = {} + for col_idx, col_flags in enumerate( + (idx_option or "").split() + ): + col_flags = int(col_flags.strip()) + col_sorting = () + # try to set flags only if they differ from PG defaults... + if col_flags & 0x01: + col_sorting += ("desc",) + if not (col_flags & 0x02): + col_sorting += ("nulls_last",) + else: + if col_flags & 0x02: + col_sorting += ("nulls_first",) + if col_sorting: + sorting[col_idx] = col_sorting + if sorting: + index["sorting"] = sorting + + index["unique"] = unique + if conrelid is not None: + index["duplicates_constraint"] = idx_name + if options: + index["options"] = dict( + [option.split("=") for option in options] + ) + + # it *might* be nice to include that this is 'btree' in the + # reflection info. But we don't want an Index object + # to have a ``postgresql_using`` in it that is just the + # default, so for the moment leaving this out. + if amname and amname != "btree": + index["amname"] = amname + + if filter_definition: + index["postgresql_where"] = filter_definition + + result = [] + for name, idx in indexes.items(): + entry = { + "name": name, + "unique": idx["unique"], + "column_names": [idx["cols"][i] for i in idx["key"]], + } + if self.server_version_info >= (11, 0): + # NOTE: this is legacy, this is part of dialect_options now + # as of #7382 + entry["include_columns"] = [idx["cols"][i] for i in idx["inc"]] + if "duplicates_constraint" in idx: + entry["duplicates_constraint"] = idx["duplicates_constraint"] + if "sorting" in idx: + entry["column_sorting"] = dict( + (idx["cols"][idx["key"][i]], value) + for i, value in idx["sorting"].items() + ) + if "include_columns" in entry: + entry.setdefault("dialect_options", {})[ + "postgresql_include" + ] = entry["include_columns"] + if "options" in idx: + entry.setdefault("dialect_options", {})[ + "postgresql_with" + ] = idx["options"] + if "amname" in idx: + entry.setdefault("dialect_options", {})[ + "postgresql_using" + ] = idx["amname"] + if "postgresql_where" in idx: + entry.setdefault("dialect_options", {})[ + "postgresql_where" + ] = idx["postgresql_where"] + result.append(entry) + return result + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) + + UNIQUE_SQL = """ + SELECT + cons.conname as name, + cons.conkey as key, + a.attnum as col_num, + a.attname as col_name + FROM + pg_catalog.pg_constraint cons + join pg_attribute a + on cons.conrelid = a.attrelid AND + a.attnum = ANY(cons.conkey) + WHERE + cons.conrelid = :table_oid AND + cons.contype = 'u' + """ + + t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode) + c = connection.execute(t, dict(table_oid=table_oid)) + + uniques = defaultdict(lambda: defaultdict(dict)) + for row in c.fetchall(): + uc = uniques[row.name] + uc["key"] = row.key + uc["cols"][row.col_num] = row.col_name + + return [ + {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]} + for name, uc in uniques.items() + ] + + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) + + COMMENT_SQL = """ + SELECT + pgd.description as table_comment + FROM + pg_catalog.pg_description pgd + WHERE + pgd.objsubid = 0 AND + pgd.objoid = :table_oid + """ + + c = connection.execute( + sql.text(COMMENT_SQL), dict(table_oid=table_oid) + ) + return {"text": c.scalar()} + + @reflection.cache + def get_check_constraints(self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) + + CHECK_SQL = """ + SELECT + cons.conname as name, + pg_get_constraintdef(cons.oid) as src + FROM + pg_catalog.pg_constraint cons + WHERE + cons.conrelid = :table_oid AND + cons.contype = 'c' + """ + + c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid)) + + ret = [] + for name, src in c: + # samples: + # "CHECK (((a > 1) AND (a < 5)))" + # "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))" + # "CHECK (((a > 1) AND (a < 5))) NOT VALID" + # "CHECK (some_boolean_function(a))" + # "CHECK (((a\n < 1)\n OR\n (a\n >= 5))\n)" + + m = re.match( + r"^CHECK *\((.+)\)( NOT VALID)?$", src, flags=re.DOTALL + ) + if not m: + util.warn("Could not parse CHECK constraint text: %r" % src) + sqltext = "" + else: + sqltext = re.compile( + r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL + ).sub(r"\1", m.group(1)) + entry = {"name": name, "sqltext": sqltext} + if m and m.group(2): + entry["dialect_options"] = {"not_valid": True} + + ret.append(entry) + return ret + + def _load_enums(self, connection, schema=None): + schema = schema or self.default_schema_name + if not self.supports_native_enum: + return {} + + # Load data types for enums: + SQL_ENUMS = """ + SELECT t.typname as "name", + -- no enum defaults in 8.4 at least + -- t.typdefault as "default", + pg_catalog.pg_type_is_visible(t.oid) as "visible", + n.nspname as "schema", + e.enumlabel as "label" + FROM pg_catalog.pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid + WHERE t.typtype = 'e' + """ + + if schema != "*": + SQL_ENUMS += "AND n.nspname = :schema " + + # e.oid gives us label order within an enum + SQL_ENUMS += 'ORDER BY "schema", "name", e.oid' + + s = sql.text(SQL_ENUMS).columns( + attname=sqltypes.Unicode, label=sqltypes.Unicode + ) + + if schema != "*": + s = s.bindparams(schema=schema) + + c = connection.execute(s) + + enums = [] + enum_by_name = {} + for enum in c.fetchall(): + key = (enum.schema, enum.name) + if key in enum_by_name: + enum_by_name[key]["labels"].append(enum.label) + else: + enum_by_name[key] = enum_rec = { + "name": enum.name, + "schema": enum.schema, + "visible": enum.visible, + "labels": [], + } + if enum.label is not None: + enum_rec["labels"].append(enum.label) + enums.append(enum_rec) + return enums + + def _load_domains(self, connection): + # Load data types for domains: + SQL_DOMAINS = """ + SELECT t.typname as "name", + pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", + not t.typnotnull as "nullable", + t.typdefault as "default", + pg_catalog.pg_type_is_visible(t.oid) as "visible", + n.nspname as "schema" + FROM pg_catalog.pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + WHERE t.typtype = 'd' + """ + + s = sql.text(SQL_DOMAINS) + c = connection.execution_options(future_result=True).execute(s) + + domains = {} + for domain in c.mappings(): + domain = domain + # strip (30) from character varying(30) + attype = re.search(r"([^\(]+)", domain["attype"]).group(1) + # 'visible' just means whether or not the domain is in a + # schema that's on the search path -- or not overridden by + # a schema with higher precedence. If it's not visible, + # it will be prefixed with the schema-name when it's used. + if domain["visible"]: + key = (domain["name"],) + else: + key = (domain["schema"], domain["name"]) + + domains[key] = { + "attype": attype, + "nullable": domain["nullable"], + "default": domain["default"], + } + + return domains diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py new file mode 100644 index 0000000..b483774 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -0,0 +1,274 @@ +# postgresql/on_conflict.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from . import ext +from ... import util +from ...sql import coercions +from ...sql import roles +from ...sql import schema +from ...sql.base import _exclusive_against +from ...sql.base import _generative +from ...sql.base import ColumnCollection +from ...sql.dml import Insert as StandardInsert +from ...sql.elements import ClauseElement +from ...sql.expression import alias +from ...util.langhelpers import public_factory + + +__all__ = ("Insert", "insert") + + +class Insert(StandardInsert): + """PostgreSQL-specific implementation of INSERT. + + Adds methods for PG-specific syntaxes such as ON CONFLICT. + + The :class:`_postgresql.Insert` object is created using the + :func:`sqlalchemy.dialects.postgresql.insert` function. + + .. versionadded:: 1.1 + + """ + + stringify_dialect = "postgresql" + inherit_cache = False + + @util.memoized_property + def excluded(self): + """Provide the ``excluded`` namespace for an ON CONFLICT statement + + PG's ON CONFLICT clause allows reference to the row that would + be inserted, known as ``excluded``. This attribute provides + all columns in this row to be referenceable. + + .. tip:: The :attr:`_postgresql.Insert.excluded` attribute is an + instance of :class:`_expression.ColumnCollection`, which provides + an interface the same as that of the :attr:`_schema.Table.c` + collection described at :ref:`metadata_tables_and_columns`. + With this collection, ordinary names are accessible like attributes + (e.g. ``stmt.excluded.some_column``), but special names and + dictionary method names should be accessed using indexed access, + such as ``stmt.excluded["column name"]`` or + ``stmt.excluded["values"]``. See the docstring for + :class:`_expression.ColumnCollection` for further examples. + + .. seealso:: + + :ref:`postgresql_insert_on_conflict` - example of how + to use :attr:`_expression.Insert.excluded` + + """ + return alias(self.table, name="excluded").columns + + _on_conflict_exclusive = _exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already has " + "an ON CONFLICT clause established" + }, + ) + + @_generative + @_on_conflict_exclusive + def on_conflict_do_update( + self, + constraint=None, + index_elements=None, + index_where=None, + set_=None, + where=None, + ): + r""" + Specifies a DO UPDATE SET action for ON CONFLICT clause. + + Either the ``constraint`` or ``index_elements`` argument is + required, but only one of these can be specified. + + :param constraint: + The name of a unique or exclusion constraint on the table, + or the constraint object itself if it has a .name attribute. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. + + :param set\_: + A dictionary or other mapping object + where the keys are either names of columns in the target table, + or :class:`_schema.Column` objects or other ORM-mapped columns + matching that of the target table, and expressions or literals + as values, specifying the ``SET`` actions to take. + + .. versionadded:: 1.4 The + :paramref:`_postgresql.Insert.on_conflict_do_update.set_` + parameter supports :class:`_schema.Column` objects from the target + :class:`_schema.Table` as keys. + + .. warning:: This dictionary does **not** take into account + Python-specified default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON CONFLICT style of + UPDATE, unless they are manually specified in the + :paramref:`.Insert.on_conflict_do_update.set_` dictionary. + + :param where: + Optional argument. If present, can be a literal SQL + string or an acceptable expression for a ``WHERE`` clause + that restricts the rows affected by ``DO UPDATE SET``. Rows + not meeting the ``WHERE`` condition will not be updated + (effectively a ``DO NOTHING`` for those rows). + + .. versionadded:: 1.1 + + + .. seealso:: + + :ref:`postgresql_insert_on_conflict` + + """ + self._post_values_clause = OnConflictDoUpdate( + constraint, index_elements, index_where, set_, where + ) + + @_generative + @_on_conflict_exclusive + def on_conflict_do_nothing( + self, constraint=None, index_elements=None, index_where=None + ): + """ + Specifies a DO NOTHING action for ON CONFLICT clause. + + The ``constraint`` and ``index_elements`` arguments + are optional, but only one of these can be specified. + + :param constraint: + The name of a unique or exclusion constraint on the table, + or the constraint object itself if it has a .name attribute. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. + + .. versionadded:: 1.1 + + .. seealso:: + + :ref:`postgresql_insert_on_conflict` + + """ + self._post_values_clause = OnConflictDoNothing( + constraint, index_elements, index_where + ) + + +insert = public_factory( + Insert, ".dialects.postgresql.insert", ".dialects.postgresql.Insert" +) + + +class OnConflictClause(ClauseElement): + stringify_dialect = "postgresql" + + def __init__(self, constraint=None, index_elements=None, index_where=None): + + if constraint is not None: + if not isinstance(constraint, util.string_types) and isinstance( + constraint, + (schema.Index, schema.Constraint, ext.ExcludeConstraint), + ): + constraint = getattr(constraint, "name") or constraint + + if constraint is not None: + if index_elements is not None: + raise ValueError( + "'constraint' and 'index_elements' are mutually exclusive" + ) + + if isinstance(constraint, util.string_types): + self.constraint_target = constraint + self.inferred_target_elements = None + self.inferred_target_whereclause = None + elif isinstance(constraint, schema.Index): + index_elements = constraint.expressions + index_where = constraint.dialect_options["postgresql"].get( + "where" + ) + elif isinstance(constraint, ext.ExcludeConstraint): + index_elements = constraint.columns + index_where = constraint.where + else: + index_elements = constraint.columns + index_where = constraint.dialect_options["postgresql"].get( + "where" + ) + + if index_elements is not None: + self.constraint_target = None + self.inferred_target_elements = index_elements + self.inferred_target_whereclause = index_where + elif constraint is None: + self.constraint_target = ( + self.inferred_target_elements + ) = self.inferred_target_whereclause = None + + +class OnConflictDoNothing(OnConflictClause): + __visit_name__ = "on_conflict_do_nothing" + + +class OnConflictDoUpdate(OnConflictClause): + __visit_name__ = "on_conflict_do_update" + + def __init__( + self, + constraint=None, + index_elements=None, + index_where=None, + set_=None, + where=None, + ): + super(OnConflictDoUpdate, self).__init__( + constraint=constraint, + index_elements=index_elements, + index_where=index_where, + ) + + if ( + self.inferred_target_elements is None + and self.constraint_target is None + ): + raise ValueError( + "Either constraint or index_elements, " + "but not both, must be specified unless DO NOTHING" + ) + + if isinstance(set_, dict): + if not set_: + raise ValueError("set parameter dictionary must not be empty") + elif isinstance(set_, ColumnCollection): + set_ = dict(set_) + else: + raise ValueError( + "set parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) + self.update_values_to_set = [ + (coercions.expect(roles.DMLColumnRole, key), value) + for key, value in set_.items() + ] + self.update_whereclause = where diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py new file mode 100644 index 0000000..9e52ee1 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -0,0 +1,277 @@ +# postgresql/ext.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .array import ARRAY +from ... import util +from ...sql import coercions +from ...sql import elements +from ...sql import expression +from ...sql import functions +from ...sql import roles +from ...sql import schema +from ...sql.schema import ColumnCollectionConstraint + + +class aggregate_order_by(expression.ColumnElement): + """Represent a PostgreSQL aggregate order by expression. + + E.g.:: + + from sqlalchemy.dialects.postgresql import aggregate_order_by + expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc())) + stmt = select(expr) + + would represent the expression:: + + SELECT array_agg(a ORDER BY b DESC) FROM table; + + Similarly:: + + expr = func.string_agg( + table.c.a, + aggregate_order_by(literal_column("','"), table.c.a) + ) + stmt = select(expr) + + Would represent:: + + SELECT string_agg(a, ',' ORDER BY a) FROM table; + + .. versionadded:: 1.1 + + .. versionchanged:: 1.2.13 - the ORDER BY argument may be multiple terms + + .. seealso:: + + :class:`_functions.array_agg` + + """ + + __visit_name__ = "aggregate_order_by" + + stringify_dialect = "postgresql" + inherit_cache = False + + def __init__(self, target, *order_by): + self.target = coercions.expect(roles.ExpressionElementRole, target) + self.type = self.target.type + + _lob = len(order_by) + if _lob == 0: + raise TypeError("at least one ORDER BY element is required") + elif _lob == 1: + self.order_by = coercions.expect( + roles.ExpressionElementRole, order_by[0] + ) + else: + self.order_by = elements.ClauseList( + *order_by, _literal_as_text_role=roles.ExpressionElementRole + ) + + def self_group(self, against=None): + return self + + def get_children(self, **kwargs): + return self.target, self.order_by + + def _copy_internals(self, clone=elements._clone, **kw): + self.target = clone(self.target, **kw) + self.order_by = clone(self.order_by, **kw) + + @property + def _from_objects(self): + return self.target._from_objects + self.order_by._from_objects + + +class ExcludeConstraint(ColumnCollectionConstraint): + """A table-level EXCLUDE constraint. + + Defines an EXCLUDE constraint as described in the `PostgreSQL + documentation`__. + + __ https://www.postgresql.org/docs/current/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE + + """ # noqa + + __visit_name__ = "exclude_constraint" + + where = None + inherit_cache = False + + create_drop_stringify_dialect = "postgresql" + + @elements._document_text_coercion( + "where", + ":class:`.ExcludeConstraint`", + ":paramref:`.ExcludeConstraint.where`", + ) + def __init__(self, *elements, **kw): + r""" + Create an :class:`.ExcludeConstraint` object. + + E.g.:: + + const = ExcludeConstraint( + (Column('period'), '&&'), + (Column('group'), '='), + where=(Column('group') != 'some group'), + ops={'group': 'my_operator_class'} + ) + + The constraint is normally embedded into the :class:`_schema.Table` + construct + directly, or added later using :meth:`.append_constraint`:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('period', TSRANGE()), + Column('group', String) + ) + + some_table.append_constraint( + ExcludeConstraint( + (some_table.c.period, '&&'), + (some_table.c.group, '='), + where=some_table.c.group != 'some group', + name='some_table_excl_const', + ops={'group': 'my_operator_class'} + ) + ) + + :param \*elements: + + A sequence of two tuples of the form ``(column, operator)`` where + "column" is a SQL expression element or a raw SQL string, most + typically a :class:`_schema.Column` object, + and "operator" is a string + containing the operator to use. In order to specify a column name + when a :class:`_schema.Column` object is not available, + while ensuring + that any necessary quoting rules take effect, an ad-hoc + :class:`_schema.Column` or :func:`_expression.column` + object should be + used. + + :param name: + Optional, the in-database name of this constraint. + + :param deferrable: + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + :param initially: + Optional string. If set, emit INITIALLY <value> when issuing DDL + for this constraint. + + :param using: + Optional string. If set, emit USING <index_method> when issuing DDL + for this constraint. Defaults to 'gist'. + + :param where: + Optional SQL expression construct or literal SQL string. + If set, emit WHERE <predicate> when issuing DDL + for this constraint. + + :param ops: + Optional dictionary. Used to define operator classes for the + elements; works the same way as that of the + :ref:`postgresql_ops <postgresql_operator_classes>` + parameter specified to the :class:`_schema.Index` construct. + + .. versionadded:: 1.3.21 + + .. seealso:: + + :ref:`postgresql_operator_classes` - general description of how + PostgreSQL operator classes are specified. + + """ + columns = [] + render_exprs = [] + self.operators = {} + + expressions, operators = zip(*elements) + + for (expr, column, strname, add_element), operator in zip( + coercions.expect_col_expression_collection( + roles.DDLConstraintColumnRole, expressions + ), + operators, + ): + if add_element is not None: + columns.append(add_element) + + name = column.name if column is not None else strname + + if name is not None: + # backwards compat + self.operators[name] = operator + + render_exprs.append((expr, name, operator)) + + self._render_exprs = render_exprs + + ColumnCollectionConstraint.__init__( + self, + *columns, + name=kw.get("name"), + deferrable=kw.get("deferrable"), + initially=kw.get("initially") + ) + self.using = kw.get("using", "gist") + where = kw.get("where") + if where is not None: + self.where = coercions.expect(roles.StatementOptionRole, where) + + self.ops = kw.get("ops", {}) + + def _set_parent(self, table, **kw): + super(ExcludeConstraint, self)._set_parent(table) + + self._render_exprs = [ + ( + expr if isinstance(expr, elements.ClauseElement) else colexpr, + name, + operator, + ) + for (expr, name, operator), colexpr in util.zip_longest( + self._render_exprs, self.columns + ) + ] + + def _copy(self, target_table=None, **kw): + elements = [ + ( + schema._copy_expression(expr, self.parent, target_table), + self.operators[expr.name], + ) + for expr in self.columns + ] + c = self.__class__( + *elements, + name=self.name, + deferrable=self.deferrable, + initially=self.initially, + where=self.where, + using=self.using + ) + c.dispatch._update(self.dispatch) + return c + + +def array_agg(*arg, **kw): + """PostgreSQL-specific form of :class:`_functions.array_agg`, ensures + return type is :class:`_postgresql.ARRAY` and not + the plain :class:`_types.ARRAY`, unless an explicit ``type_`` + is passed. + + .. versionadded:: 1.1 + + """ + kw["_default_array_type"] = ARRAY + return functions.func.array_agg(*arg, **kw) diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py new file mode 100644 index 0000000..29800d2 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -0,0 +1,455 @@ +# postgresql/hstore.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +import re + +from .array import ARRAY +from ... import types as sqltypes +from ... import util +from ...sql import functions as sqlfunc +from ...sql import operators + + +__all__ = ("HSTORE", "hstore") + +idx_precedence = operators._PRECEDENCE[operators.json_getitem_op] + +GETITEM = operators.custom_op( + "->", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +HAS_KEY = operators.custom_op( + "?", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +HAS_ALL = operators.custom_op( + "?&", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +HAS_ANY = operators.custom_op( + "?|", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +CONTAINS = operators.custom_op( + "@>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +CONTAINED_BY = operators.custom_op( + "<@", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + + +class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): + """Represent the PostgreSQL HSTORE type. + + The :class:`.HSTORE` type stores dictionaries containing strings, e.g.:: + + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', HSTORE) + ) + + with engine.connect() as conn: + conn.execute( + data_table.insert(), + data = {"key1": "value1", "key2": "value2"} + ) + + :class:`.HSTORE` provides for a wide range of operations, including: + + * Index operations:: + + data_table.c.data['some key'] == 'some value' + + * Containment operations:: + + data_table.c.data.has_key('some key') + + data_table.c.data.has_all(['one', 'two', 'three']) + + * Concatenation:: + + data_table.c.data + {"k1": "v1"} + + For a full list of special methods see + :class:`.HSTORE.comparator_factory`. + + For usage with the SQLAlchemy ORM, it may be desirable to combine + the usage of :class:`.HSTORE` with :class:`.MutableDict` dictionary + now part of the :mod:`sqlalchemy.ext.mutable` + extension. This extension will allow "in-place" changes to the + dictionary, e.g. addition of new keys or replacement/removal of existing + keys to/from the current dictionary, to produce events which will be + detected by the unit of work:: + + from sqlalchemy.ext.mutable import MutableDict + + class MyClass(Base): + __tablename__ = 'data_table' + + id = Column(Integer, primary_key=True) + data = Column(MutableDict.as_mutable(HSTORE)) + + my_object = session.query(MyClass).one() + + # in-place mutation, requires Mutable extension + # in order for the ORM to detect + my_object.data['some_key'] = 'some value' + + session.commit() + + When the :mod:`sqlalchemy.ext.mutable` extension is not used, the ORM + will not be alerted to any changes to the contents of an existing + dictionary, unless that dictionary value is re-assigned to the + HSTORE-attribute itself, thus generating a change event. + + .. seealso:: + + :class:`.hstore` - render the PostgreSQL ``hstore()`` function. + + + """ + + __visit_name__ = "HSTORE" + hashable = False + text_type = sqltypes.Text() + + def __init__(self, text_type=None): + """Construct a new :class:`.HSTORE`. + + :param text_type: the type that should be used for indexed values. + Defaults to :class:`_types.Text`. + + .. versionadded:: 1.1.0 + + """ + if text_type is not None: + self.text_type = text_type + + class Comparator( + sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator + ): + """Define comparison operations for :class:`.HSTORE`.""" + + def has_key(self, other): + """Boolean expression. Test for presence of a key. Note that the + key may be a SQLA expression. + """ + return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) + + def has_all(self, other): + """Boolean expression. Test for presence of all keys in jsonb""" + return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) + + def has_any(self, other): + """Boolean expression. Test for presence of any key in jsonb""" + return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) + + def contains(self, other, **kwargs): + """Boolean expression. Test if keys (or array) are a superset + of/contained the keys of the argument jsonb expression. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other): + """Boolean expression. Test if keys are a proper subset of the + keys of the argument jsonb expression. + """ + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) + + def _setup_getitem(self, index): + return GETITEM, index, self.type.text_type + + def defined(self, key): + """Boolean expression. Test for presence of a non-NULL value for + the key. Note that the key may be a SQLA expression. + """ + return _HStoreDefinedFunction(self.expr, key) + + def delete(self, key): + """HStore expression. Returns the contents of this hstore with the + given key deleted. Note that the key may be a SQLA expression. + """ + if isinstance(key, dict): + key = _serialize_hstore(key) + return _HStoreDeleteFunction(self.expr, key) + + def slice(self, array): + """HStore expression. Returns a subset of an hstore defined by + array of keys. + """ + return _HStoreSliceFunction(self.expr, array) + + def keys(self): + """Text array expression. Returns array of keys.""" + return _HStoreKeysFunction(self.expr) + + def vals(self): + """Text array expression. Returns array of values.""" + return _HStoreValsFunction(self.expr) + + def array(self): + """Text array expression. Returns array of alternating keys and + values. + """ + return _HStoreArrayFunction(self.expr) + + def matrix(self): + """Text array expression. Returns array of [key, value] pairs.""" + return _HStoreMatrixFunction(self.expr) + + comparator_factory = Comparator + + def bind_processor(self, dialect): + if util.py2k: + encoding = dialect.encoding + + def process(value): + if isinstance(value, dict): + return _serialize_hstore(value).encode(encoding) + else: + return value + + else: + + def process(value): + if isinstance(value, dict): + return _serialize_hstore(value) + else: + return value + + return process + + def result_processor(self, dialect, coltype): + if util.py2k: + encoding = dialect.encoding + + def process(value): + if value is not None: + return _parse_hstore(value.decode(encoding)) + else: + return value + + else: + + def process(value): + if value is not None: + return _parse_hstore(value) + else: + return value + + return process + + +class hstore(sqlfunc.GenericFunction): + """Construct an hstore value within a SQL expression using the + PostgreSQL ``hstore()`` function. + + The :class:`.hstore` function accepts one or two arguments as described + in the PostgreSQL documentation. + + E.g.:: + + from sqlalchemy.dialects.postgresql import array, hstore + + select(hstore('key1', 'value1')) + + select( + hstore( + array(['key1', 'key2', 'key3']), + array(['value1', 'value2', 'value3']) + ) + ) + + .. seealso:: + + :class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype. + + """ + + type = HSTORE + name = "hstore" + inherit_cache = True + + +class _HStoreDefinedFunction(sqlfunc.GenericFunction): + type = sqltypes.Boolean + name = "defined" + inherit_cache = True + + +class _HStoreDeleteFunction(sqlfunc.GenericFunction): + type = HSTORE + name = "delete" + inherit_cache = True + + +class _HStoreSliceFunction(sqlfunc.GenericFunction): + type = HSTORE + name = "slice" + inherit_cache = True + + +class _HStoreKeysFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = "akeys" + inherit_cache = True + + +class _HStoreValsFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = "avals" + inherit_cache = True + + +class _HStoreArrayFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = "hstore_to_array" + inherit_cache = True + + +class _HStoreMatrixFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = "hstore_to_matrix" + inherit_cache = True + + +# +# parsing. note that none of this is used with the psycopg2 backend, +# which provides its own native extensions. +# + +# My best guess at the parsing rules of hstore literals, since no formal +# grammar is given. This is mostly reverse engineered from PG's input parser +# behavior. +HSTORE_PAIR_RE = re.compile( + r""" +( + "(?P<key> (\\ . | [^"])* )" # Quoted key +) +[ ]* => [ ]* # Pair operator, optional adjoining whitespace +( + (?P<value_null> NULL ) # NULL value + | "(?P<value> (\\ . | [^"])* )" # Quoted value +) +""", + re.VERBOSE, +) + +HSTORE_DELIMITER_RE = re.compile( + r""" +[ ]* , [ ]* +""", + re.VERBOSE, +) + + +def _parse_error(hstore_str, pos): + """format an unmarshalling error.""" + + ctx = 20 + hslen = len(hstore_str) + + parsed_tail = hstore_str[max(pos - ctx - 1, 0) : min(pos, hslen)] + residual = hstore_str[min(pos, hslen) : min(pos + ctx + 1, hslen)] + + if len(parsed_tail) > ctx: + parsed_tail = "[...]" + parsed_tail[1:] + if len(residual) > ctx: + residual = residual[:-1] + "[...]" + + return "After %r, could not parse residual at position %d: %r" % ( + parsed_tail, + pos, + residual, + ) + + +def _parse_hstore(hstore_str): + """Parse an hstore from its literal string representation. + + Attempts to approximate PG's hstore input parsing rules as closely as + possible. Although currently this is not strictly necessary, since the + current implementation of hstore's output syntax is stricter than what it + accepts as input, the documentation makes no guarantees that will always + be the case. + + + + """ + result = {} + pos = 0 + pair_match = HSTORE_PAIR_RE.match(hstore_str) + + while pair_match is not None: + key = pair_match.group("key").replace(r"\"", '"').replace("\\\\", "\\") + if pair_match.group("value_null"): + value = None + else: + value = ( + pair_match.group("value") + .replace(r"\"", '"') + .replace("\\\\", "\\") + ) + result[key] = value + + pos += pair_match.end() + + delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:]) + if delim_match is not None: + pos += delim_match.end() + + pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:]) + + if pos != len(hstore_str): + raise ValueError(_parse_error(hstore_str, pos)) + + return result + + +def _serialize_hstore(val): + """Serialize a dictionary into an hstore literal. Keys and values must + both be strings (except None for values). + + """ + + def esc(s, position): + if position == "value" and s is None: + return "NULL" + elif isinstance(s, util.string_types): + return '"%s"' % s.replace("\\", "\\\\").replace('"', r"\"") + else: + raise ValueError( + "%r in %s position is not a string." % (s, position) + ) + + return ", ".join( + "%s=>%s" % (esc(k, "key"), esc(v, "value")) for k, v in val.items() + ) diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py new file mode 100644 index 0000000..daaaeac --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -0,0 +1,327 @@ +# postgresql/json.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import absolute_import + +from ... import types as sqltypes +from ... import util +from ...sql import operators + + +__all__ = ("JSON", "JSONB") + +idx_precedence = operators._PRECEDENCE[operators.json_getitem_op] + +ASTEXT = operators.custom_op( + "->>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +JSONPATH_ASTEXT = operators.custom_op( + "#>>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + + +HAS_KEY = operators.custom_op( + "?", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +HAS_ALL = operators.custom_op( + "?&", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +HAS_ANY = operators.custom_op( + "?|", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +CONTAINS = operators.custom_op( + "@>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +CONTAINED_BY = operators.custom_op( + "<@", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + + +class JSONPathType(sqltypes.JSON.JSONPathType): + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + assert isinstance(value, util.collections_abc.Sequence) + tokens = [util.text_type(elem) for elem in value] + value = "{%s}" % (", ".join(tokens)) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + assert isinstance(value, util.collections_abc.Sequence) + tokens = [util.text_type(elem) for elem in value] + value = "{%s}" % (", ".join(tokens)) + if super_proc: + value = super_proc(value) + return value + + return process + + +class JSON(sqltypes.JSON): + """Represent the PostgreSQL JSON type. + + :class:`_postgresql.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a PostgreSQL backend, + however base :class:`_types.JSON` datatype does not provide Python + accessors for PostgreSQL-specific comparison methods such as + :meth:`_postgresql.JSON.Comparator.astext`; additionally, to use + PostgreSQL ``JSONB``, the :class:`_postgresql.JSONB` datatype should + be used explicitly. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. + + The operators provided by the PostgreSQL version of :class:`_types.JSON` + include: + + * Index operations (the ``->`` operator):: + + data_table.c.data['some key'] + + data_table.c.data[5] + + + * Index operations returning text (the ``->>`` operator):: + + data_table.c.data['some key'].astext == 'some value' + + Note that equivalent functionality is available via the + :attr:`.JSON.Comparator.as_string` accessor. + + * Index operations with CAST + (equivalent to ``CAST(col ->> ['some key'] AS <type>)``):: + + data_table.c.data['some key'].astext.cast(Integer) == 5 + + Note that equivalent functionality is available via the + :attr:`.JSON.Comparator.as_integer` and similar accessors. + + * Path index operations (the ``#>`` operator):: + + data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')] + + * Path index operations returning text (the ``#>>`` operator):: + + data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value' + + .. versionchanged:: 1.1 The :meth:`_expression.ColumnElement.cast` + operator on + JSON objects now requires that the :attr:`.JSON.Comparator.astext` + modifier be called explicitly, if the cast works only from a textual + string. + + Index operations return an expression object whose type defaults to + :class:`_types.JSON` by default, + so that further JSON-oriented instructions + may be called upon the result type. + + Custom serializers and deserializers are specified at the dialect level, + that is using :func:`_sa.create_engine`. The reason for this is that when + using psycopg2, the DBAPI only allows serializers at the per-cursor + or per-connection level. E.g.:: + + engine = create_engine("postgresql://scott:tiger@localhost/test", + json_serializer=my_serialize_fn, + json_deserializer=my_deserialize_fn + ) + + When using the psycopg2 dialect, the json_deserializer is registered + against the database using ``psycopg2.extras.register_default_json``. + + .. seealso:: + + :class:`_types.JSON` - Core level JSON type + + :class:`_postgresql.JSONB` + + .. versionchanged:: 1.1 :class:`_postgresql.JSON` is now a PostgreSQL- + specific specialization of the new :class:`_types.JSON` type. + + """ # noqa + + astext_type = sqltypes.Text() + + def __init__(self, none_as_null=False, astext_type=None): + """Construct a :class:`_types.JSON` type. + + :param none_as_null: if True, persist the value ``None`` as a + SQL NULL value, not the JSON encoding of ``null``. Note that + when this flag is False, the :func:`.null` construct can still + be used to persist a NULL value:: + + from sqlalchemy import null + conn.execute(table.insert(), data=null()) + + .. versionchanged:: 0.9.8 - Added ``none_as_null``, and :func:`.null` + is now supported in order to persist a NULL value. + + .. seealso:: + + :attr:`_types.JSON.NULL` + + :param astext_type: the type to use for the + :attr:`.JSON.Comparator.astext` + accessor on indexed attributes. Defaults to :class:`_types.Text`. + + .. versionadded:: 1.1 + + """ + super(JSON, self).__init__(none_as_null=none_as_null) + if astext_type is not None: + self.astext_type = astext_type + + class Comparator(sqltypes.JSON.Comparator): + """Define comparison operations for :class:`_types.JSON`.""" + + @property + def astext(self): + """On an indexed expression, use the "astext" (e.g. "->>") + conversion when rendered in SQL. + + E.g.:: + + select(data_table.c.data['some key'].astext) + + .. seealso:: + + :meth:`_expression.ColumnElement.cast` + + """ + if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType): + return self.expr.left.operate( + JSONPATH_ASTEXT, + self.expr.right, + result_type=self.type.astext_type, + ) + else: + return self.expr.left.operate( + ASTEXT, self.expr.right, result_type=self.type.astext_type + ) + + comparator_factory = Comparator + + +class JSONB(JSON): + """Represent the PostgreSQL JSONB type. + + The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data, + e.g.:: + + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', JSONB) + ) + + with engine.connect() as conn: + conn.execute( + data_table.insert(), + data = {"key1": "value1", "key2": "value2"} + ) + + The :class:`_postgresql.JSONB` type includes all operations provided by + :class:`_types.JSON`, including the same behaviors for indexing + operations. + It also adds additional operators specific to JSONB, including + :meth:`.JSONB.Comparator.has_key`, :meth:`.JSONB.Comparator.has_all`, + :meth:`.JSONB.Comparator.has_any`, :meth:`.JSONB.Comparator.contains`, + and :meth:`.JSONB.Comparator.contained_by`. + + Like the :class:`_types.JSON` type, the :class:`_postgresql.JSONB` + type does not detect + in-place changes when used with the ORM, unless the + :mod:`sqlalchemy.ext.mutable` extension is used. + + Custom serializers and deserializers + are shared with the :class:`_types.JSON` class, + using the ``json_serializer`` + and ``json_deserializer`` keyword arguments. These must be specified + at the dialect level using :func:`_sa.create_engine`. When using + psycopg2, the serializers are associated with the jsonb type using + ``psycopg2.extras.register_default_jsonb`` on a per-connection basis, + in the same way that ``psycopg2.extras.register_default_json`` is used + to register these handlers with the json type. + + .. versionadded:: 0.9.7 + + .. seealso:: + + :class:`_types.JSON` + + """ + + __visit_name__ = "JSONB" + + class Comparator(JSON.Comparator): + """Define comparison operations for :class:`_types.JSON`.""" + + def has_key(self, other): + """Boolean expression. Test for presence of a key. Note that the + key may be a SQLA expression. + """ + return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) + + def has_all(self, other): + """Boolean expression. Test for presence of all keys in jsonb""" + return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) + + def has_any(self, other): + """Boolean expression. Test for presence of any key in jsonb""" + return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) + + def contains(self, other, **kwargs): + """Boolean expression. Test if keys (or array) are a superset + of/contained the keys of the argument jsonb expression. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other): + """Boolean expression. Test if keys are a proper subset of the + keys of the argument jsonb expression. + """ + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) + + comparator_factory = Comparator diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py new file mode 100644 index 0000000..98561a9 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -0,0 +1,594 @@ +# postgresql/pg8000.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS +# file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: postgresql+pg8000 + :name: pg8000 + :dbapi: pg8000 + :connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/pg8000/ + +.. versionchanged:: 1.4 The pg8000 dialect has been updated for version + 1.16.6 and higher, and is again part of SQLAlchemy's continuous integration + with full feature support. + +.. _pg8000_unicode: + +Unicode +------- + +pg8000 will encode / decode string values between it and the server using the +PostgreSQL ``client_encoding`` parameter; by default this is the value in +the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``. +Typically, this can be changed to ``utf-8``, as a more useful default:: + + #client_encoding = sql_ascii # actually, defaults to database + # encoding + client_encoding = utf8 + +The ``client_encoding`` can be overridden for a session by executing the SQL: + +SET CLIENT_ENCODING TO 'utf8'; + +SQLAlchemy will execute this SQL on all new connections based on the value +passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter:: + + engine = create_engine( + "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8') + +.. _pg8000_ssl: + +SSL Connections +--------------- + +pg8000 accepts a Python ``SSLContext`` object which may be specified using the +:paramref:`_sa.create_engine.connect_args` dictionary:: + + import ssl + ssl_context = ssl.create_default_context() + engine = sa.create_engine( + "postgresql+pg8000://scott:tiger@192.168.0.199/test", + connect_args={"ssl_context": ssl_context}, + ) + +If the server uses an automatically-generated certificate that is self-signed +or does not match the host name (as seen from the client), it may also be +necessary to disable hostname checking:: + + import ssl + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + engine = sa.create_engine( + "postgresql+pg8000://scott:tiger@192.168.0.199/test", + connect_args={"ssl_context": ssl_context}, + ) + +.. _pg8000_isolation_level: + +pg8000 Transaction Isolation Level +------------------------------------- + +The pg8000 dialect offers the same isolation level settings as that +of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +.. seealso:: + + :ref:`postgresql_isolation_level` + + :ref:`psycopg2_isolation_level` + + +""" # noqa +import decimal +import re +from uuid import UUID as _python_UUID + +from .array import ARRAY as PGARRAY +from .base import _ColonCast +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import ENUM +from .base import INTERVAL +from .base import PGCompiler +from .base import PGDialect +from .base import PGExecutionContext +from .base import PGIdentifierPreparer +from .base import UUID +from .json import JSON +from .json import JSONB +from .json import JSONPathType +from ... import exc +from ... import processors +from ... import types as sqltypes +from ... import util +from ...sql.elements import quoted_name + + +class _PGNumeric(sqltypes.Numeric): + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # pg8000 returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class _PGNumericNoBind(_PGNumeric): + def bind_processor(self, dialect): + return None + + +class _PGJSON(JSON): + def result_processor(self, dialect, coltype): + return None + + def get_dbapi_type(self, dbapi): + return dbapi.JSON + + +class _PGJSONB(JSONB): + def result_processor(self, dialect, coltype): + return None + + def get_dbapi_type(self, dbapi): + return dbapi.JSONB + + +class _PGJSONIndexType(sqltypes.JSON.JSONIndexType): + def get_dbapi_type(self, dbapi): + raise NotImplementedError("should not be here") + + +class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + def get_dbapi_type(self, dbapi): + return dbapi.STRING + + +class _PGJSONPathType(JSONPathType): + def get_dbapi_type(self, dbapi): + return 1009 + + +class _PGUUID(UUID): + def bind_processor(self, dialect): + if not self.as_uuid: + + def process(value): + if value is not None: + value = _python_UUID(value) + return value + + return process + + def result_processor(self, dialect, coltype): + if not self.as_uuid: + + def process(value): + if value is not None: + value = str(value) + return value + + return process + + +class _PGEnum(ENUM): + def get_dbapi_type(self, dbapi): + return dbapi.UNKNOWN + + +class _PGInterval(INTERVAL): + def get_dbapi_type(self, dbapi): + return dbapi.INTERVAL + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + return _PGInterval(precision=interval.second_precision) + + +class _PGTimeStamp(sqltypes.DateTime): + def get_dbapi_type(self, dbapi): + if self.timezone: + # TIMESTAMPTZOID + return 1184 + else: + # TIMESTAMPOID + return 1114 + + +class _PGTime(sqltypes.Time): + def get_dbapi_type(self, dbapi): + return dbapi.TIME + + +class _PGInteger(sqltypes.Integer): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class _PGSmallInteger(sqltypes.SmallInteger): + def get_dbapi_type(self, dbapi): + return dbapi.INTEGER + + +class _PGNullType(sqltypes.NullType): + def get_dbapi_type(self, dbapi): + return dbapi.NULLTYPE + + +class _PGBigInteger(sqltypes.BigInteger): + def get_dbapi_type(self, dbapi): + return dbapi.BIGINTEGER + + +class _PGBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.BOOLEAN + + +class _PGARRAY(PGARRAY): + def bind_expression(self, bindvalue): + return _ColonCast(bindvalue, self) + + +_server_side_id = util.counter() + + +class PGExecutionContext_pg8000(PGExecutionContext): + def create_server_side_cursor(self): + ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) + return ServerSideCursor(self._dbapi_connection.cursor(), ident) + + def pre_exec(self): + if not self.compiled: + return + + +class ServerSideCursor: + server_side = True + + def __init__(self, cursor, ident): + self.ident = ident + self.cursor = cursor + + @property + def connection(self): + return self.cursor.connection + + @property + def rowcount(self): + return self.cursor.rowcount + + @property + def description(self): + return self.cursor.description + + def execute(self, operation, args=(), stream=None): + op = "DECLARE " + self.ident + " NO SCROLL CURSOR FOR " + operation + self.cursor.execute(op, args, stream=stream) + return self + + def executemany(self, operation, param_sets): + self.cursor.executemany(operation, param_sets) + return self + + def fetchone(self): + self.cursor.execute("FETCH FORWARD 1 FROM " + self.ident) + return self.cursor.fetchone() + + def fetchmany(self, num=None): + if num is None: + return self.fetchall() + else: + self.cursor.execute( + "FETCH FORWARD " + str(int(num)) + " FROM " + self.ident + ) + return self.cursor.fetchall() + + def fetchall(self): + self.cursor.execute("FETCH FORWARD ALL FROM " + self.ident) + return self.cursor.fetchall() + + def close(self): + self.cursor.execute("CLOSE " + self.ident) + self.cursor.close() + + def setinputsizes(self, *sizes): + self.cursor.setinputsizes(*sizes) + + def setoutputsize(self, size, column=None): + pass + + +class PGCompiler_pg8000(PGCompiler): + def visit_mod_binary(self, binary, operator, **kw): + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) + + +class PGIdentifierPreparer_pg8000(PGIdentifierPreparer): + def __init__(self, *args, **kwargs): + PGIdentifierPreparer.__init__(self, *args, **kwargs) + self._double_percents = False + + +class PGDialect_pg8000(PGDialect): + driver = "pg8000" + supports_statement_cache = True + + supports_unicode_statements = True + + supports_unicode_binds = True + + default_paramstyle = "format" + supports_sane_multi_rowcount = True + execution_ctx_cls = PGExecutionContext_pg8000 + statement_compiler = PGCompiler_pg8000 + preparer = PGIdentifierPreparer_pg8000 + supports_server_side_cursors = True + + use_setinputsizes = True + + # reversed as of pg8000 1.16.6. 1.16.5 and lower + # are no longer compatible + description_encoding = None + # description_encoding = "use_encoding" + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric: _PGNumericNoBind, + sqltypes.Float: _PGNumeric, + sqltypes.JSON: _PGJSON, + sqltypes.Boolean: _PGBoolean, + sqltypes.NullType: _PGNullType, + JSONB: _PGJSONB, + sqltypes.JSON.JSONPathType: _PGJSONPathType, + sqltypes.JSON.JSONIndexType: _PGJSONIndexType, + sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType, + UUID: _PGUUID, + sqltypes.Interval: _PGInterval, + INTERVAL: _PGInterval, + sqltypes.DateTime: _PGTimeStamp, + sqltypes.Time: _PGTime, + sqltypes.Integer: _PGInteger, + sqltypes.SmallInteger: _PGSmallInteger, + sqltypes.BigInteger: _PGBigInteger, + sqltypes.Enum: _PGEnum, + sqltypes.ARRAY: _PGARRAY, + }, + ) + + def __init__(self, client_encoding=None, **kwargs): + PGDialect.__init__(self, **kwargs) + self.client_encoding = client_encoding + + if self._dbapi_version < (1, 16, 6): + raise NotImplementedError("pg8000 1.16.6 or greater is required") + + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + return tuple( + [ + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) + else: + return (99, 99, 99) + + @classmethod + def dbapi(cls): + return __import__("pg8000") + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.InterfaceError) and "network error" in str( + e + ): + # new as of pg8000 1.19.0 for broken connections + return True + + # connection was closed normally + return "connection is closed" in str(e) + + def set_isolation_level(self, connection, level): + level = level.replace("_", " ") + + # adjust for ConnectionFairy possibly being present + if hasattr(connection, "dbapi_connection"): + connection = connection.dbapi_connection + + if level == "AUTOCOMMIT": + connection.autocommit = True + elif level in self._isolation_lookup: + connection.autocommit = False + cursor = connection.cursor() + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION " + "ISOLATION LEVEL %s" % level + ) + cursor.execute("COMMIT") + cursor.close() + else: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s or AUTOCOMMIT" + % (level, self.name, ", ".join(self._isolation_lookup)) + ) + + def set_readonly(self, connection, value): + cursor = connection.cursor() + try: + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION %s" + % ("READ ONLY" if value else "READ WRITE") + ) + cursor.execute("COMMIT") + finally: + cursor.close() + + def get_readonly(self, connection): + cursor = connection.cursor() + try: + cursor.execute("show transaction_read_only") + val = cursor.fetchone()[0] + finally: + cursor.close() + + return val == "on" + + def set_deferrable(self, connection, value): + cursor = connection.cursor() + try: + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION %s" + % ("DEFERRABLE" if value else "NOT DEFERRABLE") + ) + cursor.execute("COMMIT") + finally: + cursor.close() + + def get_deferrable(self, connection): + cursor = connection.cursor() + try: + cursor.execute("show transaction_deferrable") + val = cursor.fetchone()[0] + finally: + cursor.close() + + return val == "on" + + def set_client_encoding(self, connection, client_encoding): + # adjust for ConnectionFairy possibly being present + if hasattr(connection, "dbapi_connection"): + connection = connection.dbapi_connection + + cursor = connection.cursor() + cursor.execute("SET CLIENT_ENCODING TO '" + client_encoding + "'") + cursor.execute("COMMIT") + cursor.close() + + def do_set_input_sizes(self, cursor, list_of_tuples, context): + if self.positional: + cursor.setinputsizes( + *[dbtype for key, dbtype, sqltype in list_of_tuples] + ) + else: + cursor.setinputsizes( + **{ + key: dbtype + for key, dbtype, sqltype in list_of_tuples + if dbtype + } + ) + + def do_begin_twophase(self, connection, xid): + connection.connection.tpc_begin((0, xid, "")) + + def do_prepare_twophase(self, connection, xid): + connection.connection.tpc_prepare() + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_rollback((0, xid, "")) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_commit((0, xid, "")) + + def do_recover_twophase(self, connection): + return [row[1] for row in connection.connection.tpc_recover()] + + def on_connect(self): + fns = [] + + def on_connect(conn): + conn.py_types[quoted_name] = conn.py_types[util.text_type] + + fns.append(on_connect) + + if self.client_encoding is not None: + + def on_connect(conn): + self.set_client_encoding(conn, self.client_encoding) + + fns.append(on_connect) + + if self.isolation_level is not None: + + def on_connect(conn): + self.set_isolation_level(conn, self.isolation_level) + + fns.append(on_connect) + + if self._json_deserializer: + + def on_connect(conn): + # json + conn.register_in_adapter(114, self._json_deserializer) + + # jsonb + conn.register_in_adapter(3802, self._json_deserializer) + + fns.append(on_connect) + + if len(fns) > 0: + + def on_connect(conn): + for fn in fns: + fn(conn) + + return on_connect + else: + return None + + +dialect = PGDialect_pg8000 diff --git a/lib/sqlalchemy/dialects/postgresql/provision.py b/lib/sqlalchemy/dialects/postgresql/provision.py new file mode 100644 index 0000000..98470f3 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/provision.py @@ -0,0 +1,124 @@ +import time + +from ... import exc +from ... import inspect +from ... import text +from ...testing import warn_test_suite +from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_post_tables +from ...testing.provision import drop_all_schema_objects_pre_tables +from ...testing.provision import drop_db +from ...testing.provision import log +from ...testing.provision import prepare_for_drop_tables +from ...testing.provision import set_default_schema_on_connection +from ...testing.provision import temp_table_keyword_args + + +@create_db.for_db("postgresql") +def _pg_create_db(cfg, eng, ident): + template_db = cfg.options.postgresql_templatedb + + with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn: + + if not template_db: + template_db = conn.exec_driver_sql( + "select current_database()" + ).scalar() + + attempt = 0 + while True: + try: + conn.exec_driver_sql( + "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db) + ) + except exc.OperationalError as err: + attempt += 1 + if attempt >= 3: + raise + if "accessed by other users" in str(err): + log.info( + "Waiting to create %s, URI %r, " + "template DB %s is in use sleeping for .5", + ident, + eng.url, + template_db, + ) + time.sleep(0.5) + except: + raise + else: + break + + +@drop_db.for_db("postgresql") +def _pg_drop_db(cfg, eng, ident): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + with conn.begin(): + conn.execute( + text( + "select pg_terminate_backend(pid) from pg_stat_activity " + "where usename=current_user and pid != pg_backend_pid() " + "and datname=:dname" + ), + dict(dname=ident), + ) + conn.exec_driver_sql("DROP DATABASE %s" % ident) + + +@temp_table_keyword_args.for_db("postgresql") +def _postgresql_temp_table_keyword_args(cfg, eng): + return {"prefixes": ["TEMPORARY"]} + + +@set_default_schema_on_connection.for_db("postgresql") +def _postgresql_set_default_schema_on_connection( + cfg, dbapi_connection, schema_name +): + existing_autocommit = dbapi_connection.autocommit + dbapi_connection.autocommit = True + cursor = dbapi_connection.cursor() + cursor.execute("SET SESSION search_path='%s'" % schema_name) + cursor.close() + dbapi_connection.autocommit = existing_autocommit + + +@drop_all_schema_objects_pre_tables.for_db("postgresql") +def drop_all_schema_objects_pre_tables(cfg, eng): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + for xid in conn.execute("select gid from pg_prepared_xacts").scalars(): + conn.execute("ROLLBACK PREPARED '%s'" % xid) + + +@drop_all_schema_objects_post_tables.for_db("postgresql") +def drop_all_schema_objects_post_tables(cfg, eng): + from sqlalchemy.dialects import postgresql + + inspector = inspect(eng) + with eng.begin() as conn: + for enum in inspector.get_enums("*"): + conn.execute( + postgresql.DropEnumType( + postgresql.ENUM(name=enum["name"], schema=enum["schema"]) + ) + ) + + +@prepare_for_drop_tables.for_db("postgresql") +def prepare_for_drop_tables(config, connection): + """Ensure there are no locks on the current username/database.""" + + result = connection.exec_driver_sql( + "select pid, state, wait_event_type, query " + # "select pg_terminate_backend(pid), state, wait_event_type " + "from pg_stat_activity where " + "usename=current_user " + "and datname=current_database() and state='idle in transaction' " + "and pid != pg_backend_pid()" + ) + rows = result.all() # noqa + if rows: + warn_test_suite( + "PostgreSQL may not be able to DROP tables due to " + "idle in transaction: %s" + % ("; ".join(row._mapping["query"] for row in rows)) + ) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py new file mode 100644 index 0000000..6747427 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -0,0 +1,1088 @@ +# postgresql/psycopg2.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: postgresql+psycopg2 + :name: psycopg2 + :dbapi: psycopg2 + :connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/psycopg2/ + +psycopg2 Connect Arguments +-------------------------- + +Keyword arguments that are specific to the SQLAlchemy psycopg2 dialect +may be passed to :func:`_sa.create_engine()`, and include the following: + + +* ``isolation_level``: This option, available for all PostgreSQL dialects, + includes the ``AUTOCOMMIT`` isolation level when using the psycopg2 + dialect. This option sets the **default** isolation level for the + connection that is set immediately upon connection to the database before + the connection is pooled. This option is generally superseded by the more + modern :paramref:`_engine.Connection.execution_options.isolation_level` + execution option, detailed at :ref:`dbapi_autocommit`. + + .. seealso:: + + :ref:`psycopg2_isolation_level` + + :ref:`dbapi_autocommit` + + +* ``client_encoding``: sets the client encoding in a libpq-agnostic way, + using psycopg2's ``set_client_encoding()`` method. + + .. seealso:: + + :ref:`psycopg2_unicode` + +* ``use_native_unicode``: Under Python 2 only, this can be set to False to + disable the use of psycopg2's native Unicode support. + + .. seealso:: + + :ref:`psycopg2_disable_native_unicode` + + +* ``executemany_mode``, ``executemany_batch_page_size``, + ``executemany_values_page_size``: Allows use of psycopg2 + extensions for optimizing "executemany"-style queries. See the referenced + section below for details. + + .. seealso:: + + :ref:`psycopg2_executemany_mode` + +.. tip:: + + The above keyword arguments are **dialect** keyword arguments, meaning + that they are passed as explicit keyword arguments to :func:`_sa.create_engine()`:: + + engine = create_engine( + "postgresql+psycopg2://scott:tiger@localhost/test", + isolation_level="SERIALIZABLE", + ) + + These should not be confused with **DBAPI** connect arguments, which + are passed as part of the :paramref:`_sa.create_engine.connect_args` + dictionary and/or are passed in the URL query string, as detailed in + the section :ref:`custom_dbapi_args`. + +.. _psycopg2_ssl: + +SSL Connections +--------------- + +The psycopg2 module has a connection argument named ``sslmode`` for +controlling its behavior regarding secure (SSL) connections. The default is +``sslmode=prefer``; it will attempt an SSL connection and if that fails it +will fall back to an unencrypted connection. ``sslmode=require`` may be used +to ensure that only secure connections are established. Consult the +psycopg2 / libpq documentation for further options that are available. + +Note that ``sslmode`` is specific to psycopg2 so it is included in the +connection URI:: + + engine = sa.create_engine( + "postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require" + ) + +Unix Domain Connections +------------------------ + +psycopg2 supports connecting via Unix domain connections. When the ``host`` +portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2, +which specifies Unix-domain communication rather than TCP/IP communication:: + + create_engine("postgresql+psycopg2://user:password@/dbname") + +By default, the socket file used is to connect to a Unix-domain socket +in ``/tmp``, or whatever socket directory was specified when PostgreSQL +was built. This value can be overridden by passing a pathname to psycopg2, +using ``host`` as an additional keyword argument:: + + create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql") + +.. seealso:: + + `PQconnectdbParams \ + <https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_ + +.. _psycopg2_multi_host: + +Specifying multiple fallback hosts +----------------------------------- + +psycopg2 supports multiple connection points in the connection string. +When the ``host`` parameter is used multiple times in the query section of +the URL, SQLAlchemy will create a single string of the host and port +information provided to make the connections. Tokens may consist of +``host::port`` or just ``host``; in the latter case, the default port +is selected by libpq. In the example below, three host connections +are specified, for ``HostA::PortA``, ``HostB`` connecting to the default port, +and ``HostC::PortC``:: + + create_engine( + "postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC" + ) + +As an alternative, libpq query string format also may be used; this specifies +``host`` and ``port`` as single query string arguments with comma-separated +lists - the default port can be chosen by indicating an empty value +in the comma separated list:: + + create_engine( + "postgresql+psycopg2://user:password@/dbname?host=HostA,HostB,HostC&port=PortA,,PortC" + ) + +With either URL style, connections to each host is attempted based on a +configurable strategy, which may be configured using the libpq +``target_session_attrs`` parameter. Per libpq this defaults to ``any`` +which indicates a connection to each host is then attempted until a connection is successful. +Other strategies include ``primary``, ``prefer-standby``, etc. The complete +list is documented by PostgreSQL at +`libpq connection strings <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>`_. + +For example, to indicate two hosts using the ``primary`` strategy:: + + create_engine( + "postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC&target_session_attrs=primary" + ) + +.. versionchanged:: 1.4.40 Port specification in psycopg2 multiple host format + is repaired, previously ports were not correctly interpreted in this context. + libpq comma-separated format is also now supported. + +.. versionadded:: 1.3.20 Support for multiple hosts in PostgreSQL connection + string. + +.. seealso:: + + `libpq connection strings <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>`_ - please refer + to this section in the libpq documentation for complete background on multiple host support. + + +Empty DSN Connections / Environment Variable Connections +--------------------------------------------------------- + +The psycopg2 DBAPI can connect to PostgreSQL by passing an empty DSN to the +libpq client library, which by default indicates to connect to a localhost +PostgreSQL database that is open for "trust" connections. This behavior can be +further tailored using a particular set of environment variables which are +prefixed with ``PG_...``, which are consumed by ``libpq`` to take the place of +any or all elements of the connection string. + +For this form, the URL can be passed without any elements other than the +initial scheme:: + + engine = create_engine('postgresql+psycopg2://') + +In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()`` +function which in turn represents an empty DSN passed to libpq. + +.. versionadded:: 1.3.2 support for parameter-less connections with psycopg2. + +.. seealso:: + + `Environment Variables\ + <https://www.postgresql.org/docs/current/libpq-envars.html>`_ - + PostgreSQL documentation on how to use ``PG_...`` + environment variables for connections. + +.. _psycopg2_execution_options: + +Per-Statement/Connection Execution Options +------------------------------------------- + +The following DBAPI-specific options are respected when used with +:meth:`_engine.Connection.execution_options`, +:meth:`.Executable.execution_options`, +:meth:`_query.Query.execution_options`, +in addition to those not specific to DBAPIs: + +* ``isolation_level`` - Set the transaction isolation level for the lifespan + of a :class:`_engine.Connection` (can only be set on a connection, + not a statement + or query). See :ref:`psycopg2_isolation_level`. + +* ``stream_results`` - Enable or disable usage of psycopg2 server side + cursors - this feature makes use of "named" cursors in combination with + special result handling methods so that result rows are not fully buffered. + Defaults to False, meaning cursors are buffered by default. + +* ``max_row_buffer`` - when using ``stream_results``, an integer value that + specifies the maximum number of rows to buffer at a time. This is + interpreted by the :class:`.BufferedRowCursorResult`, and if omitted the + buffer will grow to ultimately store 1000 rows at a time. + + .. versionchanged:: 1.4 The ``max_row_buffer`` size can now be greater than + 1000, and the buffer will grow to that size. + +.. _psycopg2_batch_mode: + +.. _psycopg2_executemany_mode: + +Psycopg2 Fast Execution Helpers +------------------------------- + +Modern versions of psycopg2 include a feature known as +`Fast Execution Helpers \ +<https://initd.org/psycopg/docs/extras.html#fast-execution-helpers>`_, which +have been shown in benchmarking to improve psycopg2's executemany() +performance, primarily with INSERT statements, by multiple orders of magnitude. +SQLAlchemy internally makes use of these extensions for ``executemany()`` style +calls, which correspond to lists of parameters being passed to +:meth:`_engine.Connection.execute` as detailed in :ref:`multiple parameter +sets <tutorial_multiple_parameters>`. The ORM also uses this mode internally whenever +possible. + +The two available extensions on the psycopg2 side are the ``execute_values()`` +and ``execute_batch()`` functions. The psycopg2 dialect defaults to using the +``execute_values()`` extension for all qualifying INSERT statements. + +.. versionchanged:: 1.4 The psycopg2 dialect now defaults to a new mode + ``"values_only"`` for ``executemany_mode``, which allows an order of + magnitude performance improvement for INSERT statements, but does not + include "batch" mode for UPDATE and DELETE statements which removes the + ability of ``cursor.rowcount`` to function correctly. + +The use of these extensions is controlled by the ``executemany_mode`` flag +which may be passed to :func:`_sa.create_engine`:: + + engine = create_engine( + "postgresql+psycopg2://scott:tiger@host/dbname", + executemany_mode='values_plus_batch') + + +Possible options for ``executemany_mode`` include: + +* ``values_only`` - this is the default value. the psycopg2 execute_values() + extension is used for qualifying INSERT statements, which rewrites the INSERT + to include multiple VALUES clauses so that many parameter sets can be + inserted with one statement. + + .. versionadded:: 1.4 Added ``"values_only"`` setting for ``executemany_mode`` + which is also now the default. + +* ``None`` - No psycopg2 extensions are not used, and the usual + ``cursor.executemany()`` method is used when invoking statements with + multiple parameter sets. + +* ``'batch'`` - Uses ``psycopg2.extras.execute_batch`` for all qualifying + INSERT, UPDATE and DELETE statements, so that multiple copies + of a SQL query, each one corresponding to a parameter set passed to + ``executemany()``, are joined into a single SQL string separated by a + semicolon. When using this mode, the :attr:`_engine.CursorResult.rowcount` + attribute will not contain a value for executemany-style executions. + +* ``'values_plus_batch'``- ``execute_values`` is used for qualifying INSERT + statements, ``execute_batch`` is used for UPDATE and DELETE. + When using this mode, the :attr:`_engine.CursorResult.rowcount` + attribute will not contain a value for executemany-style executions against + UPDATE and DELETE statements. + +By "qualifying statements", we mean that the statement being executed +must be a Core :func:`_expression.insert`, :func:`_expression.update` +or :func:`_expression.delete` construct, and not a plain textual SQL +string or one constructed using :func:`_expression.text`. When using the +ORM, all insert/update/delete statements used by the ORM flush process +are qualifying. + +The "page size" for the "values" and "batch" strategies can be affected +by using the ``executemany_batch_page_size`` and +``executemany_values_page_size`` engine parameters. These +control how many parameter sets +should be represented in each execution. The "values" page size defaults +to 1000, which is different that psycopg2's default. The "batch" page +size defaults to 100. These can be affected by passing new values to +:func:`_engine.create_engine`:: + + engine = create_engine( + "postgresql+psycopg2://scott:tiger@host/dbname", + executemany_mode='values', + executemany_values_page_size=10000, executemany_batch_page_size=500) + +.. versionchanged:: 1.4 + + The default for ``executemany_values_page_size`` is now 1000, up from + 100. + +.. seealso:: + + :ref:`tutorial_multiple_parameters` - General information on using the + :class:`_engine.Connection` + object to execute statements in such a way as to make + use of the DBAPI ``.executemany()`` method. + + +.. _psycopg2_unicode: + +Unicode with Psycopg2 +---------------------- + +The psycopg2 DBAPI driver supports Unicode data transparently. Under Python 2 +only, the SQLAlchemy psycopg2 dialect will enable the +``psycopg2.extensions.UNICODE`` extension by default to ensure Unicode is +handled properly; under Python 3, this is psycopg2's default behavior. + +The client character encoding can be controlled for the psycopg2 dialect +in the following ways: + +* For PostgreSQL 9.1 and above, the ``client_encoding`` parameter may be + passed in the database URL; this parameter is consumed by the underlying + ``libpq`` PostgreSQL client library:: + + engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8") + + Alternatively, the above ``client_encoding`` value may be passed using + :paramref:`_sa.create_engine.connect_args` for programmatic establishment with + ``libpq``:: + + engine = create_engine( + "postgresql+psycopg2://user:pass@host/dbname", + connect_args={'client_encoding': 'utf8'} + ) + +* For all PostgreSQL versions, psycopg2 supports a client-side encoding + value that will be passed to database connections when they are first + established. The SQLAlchemy psycopg2 dialect supports this using the + ``client_encoding`` parameter passed to :func:`_sa.create_engine`:: + + engine = create_engine( + "postgresql+psycopg2://user:pass@host/dbname", + client_encoding="utf8" + ) + + .. tip:: The above ``client_encoding`` parameter admittedly is very similar + in appearance to usage of the parameter within the + :paramref:`_sa.create_engine.connect_args` dictionary; the difference + above is that the parameter is consumed by psycopg2 and is + passed to the database connection using ``SET client_encoding TO + 'utf8'``; in the previously mentioned style, the parameter is instead + passed through psycopg2 and consumed by the ``libpq`` library. + +* A common way to set up client encoding with PostgreSQL databases is to + ensure it is configured within the server-side postgresql.conf file; + this is the recommended way to set encoding for a server that is + consistently of one encoding in all databases:: + + # postgresql.conf file + + # client_encoding = sql_ascii # actually, defaults to database + # encoding + client_encoding = utf8 + +.. _psycopg2_disable_native_unicode: + +Disabling Native Unicode +^^^^^^^^^^^^^^^^^^^^^^^^ + +Under Python 2 only, SQLAlchemy can also be instructed to skip the usage of the +psycopg2 ``UNICODE`` extension and to instead utilize its own unicode +encode/decode services, which are normally reserved only for those DBAPIs that +don't fully support unicode directly. Passing ``use_native_unicode=False`` to +:func:`_sa.create_engine` will disable usage of ``psycopg2.extensions. +UNICODE``. SQLAlchemy will instead encode data itself into Python bytestrings +on the way in and coerce from bytes on the way back, using the value of the +:func:`_sa.create_engine` ``encoding`` parameter, which defaults to ``utf-8``. +SQLAlchemy's own unicode encode/decode functionality is steadily becoming +obsolete as most DBAPIs now support unicode fully. + + +Transactions +------------ + +The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. + +.. _psycopg2_isolation_level: + +Psycopg2 Transaction Isolation Level +------------------------------------- + +As discussed in :ref:`postgresql_isolation_level`, +all PostgreSQL dialects support setting of transaction isolation level +both via the ``isolation_level`` parameter passed to :func:`_sa.create_engine` +, +as well as the ``isolation_level`` argument used by +:meth:`_engine.Connection.execution_options`. When using the psycopg2 dialect +, these +options make use of psycopg2's ``set_isolation_level()`` connection method, +rather than emitting a PostgreSQL directive; this is because psycopg2's +API-level setting is always emitted at the start of each transaction in any +case. + +The psycopg2 dialect supports these constants for isolation level: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +.. seealso:: + + :ref:`postgresql_isolation_level` + + :ref:`pg8000_isolation_level` + + +NOTICE logging +--------------- + +The psycopg2 dialect will log PostgreSQL NOTICE messages +via the ``sqlalchemy.dialects.postgresql`` logger. When this logger +is set to the ``logging.INFO`` level, notice messages will be logged:: + + import logging + + logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) + +Above, it is assumed that logging is configured externally. If this is not +the case, configuration such as ``logging.basicConfig()`` must be utilized:: + + import logging + + logging.basicConfig() # log messages to stdout + logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) + +.. seealso:: + + `Logging HOWTO <https://docs.python.org/3/howto/logging.html>`_ - on the python.org website + +.. _psycopg2_hstore: + +HSTORE type +------------ + +The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of +the HSTORE type. The SQLAlchemy psycopg2 dialect will enable this extension +by default when psycopg2 version 2.4 or greater is used, and +it is detected that the target database has the HSTORE type set up for use. +In other words, when the dialect makes the first +connection, a sequence like the following is performed: + +1. Request the available HSTORE oids using + ``psycopg2.extras.HstoreAdapter.get_oids()``. + If this function returns a list of HSTORE identifiers, we then determine + that the ``HSTORE`` extension is present. + This function is **skipped** if the version of psycopg2 installed is + less than version 2.4. + +2. If the ``use_native_hstore`` flag is at its default of ``True``, and + we've detected that ``HSTORE`` oids are available, the + ``psycopg2.extensions.register_hstore()`` extension is invoked for all + connections. + +The ``register_hstore()`` extension has the effect of **all Python +dictionaries being accepted as parameters regardless of the type of target +column in SQL**. The dictionaries are converted by this extension into a +textual HSTORE expression. If this behavior is not desired, disable the +use of the hstore extension by setting ``use_native_hstore`` to ``False`` as +follows:: + + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", + use_native_hstore=False) + +The ``HSTORE`` type is **still supported** when the +``psycopg2.extensions.register_hstore()`` extension is not used. It merely +means that the coercion between Python dictionaries and the HSTORE +string format, on both the parameter side and the result side, will take +place within SQLAlchemy's own marshalling logic, and not that of ``psycopg2`` +which may be more performant. + +""" # noqa +from __future__ import absolute_import + +import decimal +import logging +import re +from uuid import UUID as _python_UUID + +from .array import ARRAY as PGARRAY +from .base import _ColonCast +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import ENUM +from .base import PGCompiler +from .base import PGDialect +from .base import PGExecutionContext +from .base import PGIdentifierPreparer +from .base import UUID +from .hstore import HSTORE +from .json import JSON +from .json import JSONB +from ... import exc +from ... import processors +from ... import types as sqltypes +from ... import util +from ...engine import cursor as _cursor +from ...util import collections_abc + + +logger = logging.getLogger("sqlalchemy.dialects.postgresql") + + +class _PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # pg8000 returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class _PGEnum(ENUM): + def result_processor(self, dialect, coltype): + if util.py2k and self._expect_unicode is True: + # for py2k, if the enum type needs unicode data (which is set up as + # part of the Enum() constructor based on values passed as py2k + # unicode objects) we have to use our own converters since + # psycopg2's don't work, a rare exception to the "modern DBAPIs + # support unicode everywhere" theme of deprecating + # convert_unicode=True. Use the special "force_nocheck" directive + # which forces unicode conversion to happen on the Python side + # without an isinstance() check. in py3k psycopg2 does the right + # thing automatically. + self._expect_unicode = "force_nocheck" + return super(_PGEnum, self).result_processor(dialect, coltype) + + +class _PGHStore(HSTORE): + def bind_processor(self, dialect): + if dialect._has_native_hstore: + return None + else: + return super(_PGHStore, self).bind_processor(dialect) + + def result_processor(self, dialect, coltype): + if dialect._has_native_hstore: + return None + else: + return super(_PGHStore, self).result_processor(dialect, coltype) + + +class _PGARRAY(PGARRAY): + def bind_expression(self, bindvalue): + return _ColonCast(bindvalue, self) + + +class _PGJSON(JSON): + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONB(JSONB): + def result_processor(self, dialect, coltype): + return None + + +class _PGUUID(UUID): + def bind_processor(self, dialect): + if not self.as_uuid and dialect.use_native_uuid: + + def process(value): + if value is not None: + value = _python_UUID(value) + return value + + return process + + def result_processor(self, dialect, coltype): + if not self.as_uuid and dialect.use_native_uuid: + + def process(value): + if value is not None: + value = str(value) + return value + + return process + + +_server_side_id = util.counter() + + +class PGExecutionContext_psycopg2(PGExecutionContext): + _psycopg2_fetched_rows = None + + def create_server_side_cursor(self): + # use server-side cursors: + # https://lists.initd.org/pipermail/psycopg/2007-January/005251.html + ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) + return self._dbapi_connection.cursor(ident) + + def post_exec(self): + if ( + self._psycopg2_fetched_rows + and self.compiled + and self.compiled.returning + ): + # psycopg2 execute_values will provide for a real cursor where + # cursor.description works correctly. however, it executes the + # INSERT statement multiple times for multiple pages of rows, so + # while this cursor also supports calling .fetchall() directly, in + # order to get the list of all rows inserted across multiple pages, + # we have to retrieve the aggregated list from the execute_values() + # function directly. + strat_cls = _cursor.FullyBufferedCursorFetchStrategy + self.cursor_fetch_strategy = strat_cls( + self.cursor, initial_buffer=self._psycopg2_fetched_rows + ) + self._log_notices(self.cursor) + + def _log_notices(self, cursor): + # check also that notices is an iterable, after it's already + # established that we will be iterating through it. This is to get + # around test suites such as SQLAlchemy's using a Mock object for + # cursor + if not cursor.connection.notices or not isinstance( + cursor.connection.notices, collections_abc.Iterable + ): + return + + for notice in cursor.connection.notices: + # NOTICE messages have a + # newline character at the end + logger.info(notice.rstrip()) + + cursor.connection.notices[:] = [] + + +class PGCompiler_psycopg2(PGCompiler): + pass + + +class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): + pass + + +EXECUTEMANY_PLAIN = util.symbol("executemany_plain", canonical=0) +EXECUTEMANY_BATCH = util.symbol("executemany_batch", canonical=1) +EXECUTEMANY_VALUES = util.symbol("executemany_values", canonical=2) +EXECUTEMANY_VALUES_PLUS_BATCH = util.symbol( + "executemany_values_plus_batch", + canonical=EXECUTEMANY_BATCH | EXECUTEMANY_VALUES, +) + + +class PGDialect_psycopg2(PGDialect): + driver = "psycopg2" + + supports_statement_cache = True + + if util.py2k: + # turn off supports_unicode_statements for Python 2. psycopg2 supports + # unicode statements in Py2K. But! it does not support unicode *bound + # parameter names* because it uses the Python "%" operator to + # interpolate these into the string, and this fails. So for Py2K, we + # have to use full-on encoding for statements and parameters before + # passing to cursor.execute(). + supports_unicode_statements = False + + supports_server_side_cursors = True + + default_paramstyle = "pyformat" + # set to true based on psycopg2 version + supports_sane_multi_rowcount = False + execution_ctx_cls = PGExecutionContext_psycopg2 + statement_compiler = PGCompiler_psycopg2 + preparer = PGIdentifierPreparer_psycopg2 + psycopg2_version = (0, 0) + + _has_native_hstore = True + + engine_config_types = PGDialect.engine_config_types.union( + {"use_native_unicode": util.asbool} + ) + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric: _PGNumeric, + ENUM: _PGEnum, # needs force_unicode + sqltypes.Enum: _PGEnum, # needs force_unicode + HSTORE: _PGHStore, + JSON: _PGJSON, + sqltypes.JSON: _PGJSON, + JSONB: _PGJSONB, + UUID: _PGUUID, + sqltypes.ARRAY: _PGARRAY, + }, + ) + + def __init__( + self, + use_native_unicode=True, + client_encoding=None, + use_native_hstore=True, + use_native_uuid=True, + executemany_mode="values_only", + executemany_batch_page_size=100, + executemany_values_page_size=1000, + **kwargs + ): + PGDialect.__init__(self, **kwargs) + self.use_native_unicode = use_native_unicode + if not use_native_unicode and not util.py2k: + raise exc.ArgumentError( + "psycopg2 native_unicode mode is required under Python 3" + ) + if not use_native_hstore: + self._has_native_hstore = False + self.use_native_hstore = use_native_hstore + self.use_native_uuid = use_native_uuid + self.supports_unicode_binds = use_native_unicode + self.client_encoding = client_encoding + + # Parse executemany_mode argument, allowing it to be only one of the + # symbol names + self.executemany_mode = util.symbol.parse_user_argument( + executemany_mode, + { + EXECUTEMANY_PLAIN: [None], + EXECUTEMANY_BATCH: ["batch"], + EXECUTEMANY_VALUES: ["values_only"], + EXECUTEMANY_VALUES_PLUS_BATCH: ["values_plus_batch", "values"], + }, + "executemany_mode", + ) + + if self.executemany_mode & EXECUTEMANY_VALUES: + self.insert_executemany_returning = True + + self.executemany_batch_page_size = executemany_batch_page_size + self.executemany_values_page_size = executemany_values_page_size + + if self.dbapi and hasattr(self.dbapi, "__version__"): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) + if m: + self.psycopg2_version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + + if self.psycopg2_version < (2, 7): + raise ImportError( + "psycopg2 version 2.7 or higher is required." + ) + + def initialize(self, connection): + super(PGDialect_psycopg2, self).initialize(connection) + self._has_native_hstore = ( + self.use_native_hstore + and self._hstore_oids(connection.connection) is not None + ) + + # PGDialect.initialize() checks server version for <= 8.2 and sets + # this flag to False if so + if not self.full_returning: + self.insert_executemany_returning = False + self.executemany_mode = EXECUTEMANY_PLAIN + + self.supports_sane_multi_rowcount = not ( + self.executemany_mode & EXECUTEMANY_BATCH + ) + + @classmethod + def dbapi(cls): + import psycopg2 + + return psycopg2 + + @classmethod + def _psycopg2_extensions(cls): + from psycopg2 import extensions + + return extensions + + @classmethod + def _psycopg2_extras(cls): + from psycopg2 import extras + + return extras + + @util.memoized_property + def _isolation_lookup(self): + extensions = self._psycopg2_extensions() + return { + "AUTOCOMMIT": extensions.ISOLATION_LEVEL_AUTOCOMMIT, + "READ COMMITTED": extensions.ISOLATION_LEVEL_READ_COMMITTED, + "READ UNCOMMITTED": extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, + "REPEATABLE READ": extensions.ISOLATION_LEVEL_REPEATABLE_READ, + "SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE, + } + + def set_isolation_level(self, connection, level): + try: + level = self._isolation_lookup[level.replace("_", " ")] + except KeyError as err: + util.raise_( + exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) + ), + replace_context=err, + ) + + connection.set_isolation_level(level) + + def set_readonly(self, connection, value): + connection.readonly = value + + def get_readonly(self, connection): + return connection.readonly + + def set_deferrable(self, connection, value): + connection.deferrable = value + + def get_deferrable(self, connection): + return connection.deferrable + + def do_ping(self, dbapi_connection): + cursor = None + before_autocommit = dbapi_connection.autocommit + try: + if not before_autocommit: + dbapi_connection.autocommit = True + cursor = dbapi_connection.cursor() + try: + cursor.execute(self._dialect_specific_select_one) + finally: + cursor.close() + if not before_autocommit and not dbapi_connection.closed: + dbapi_connection.autocommit = before_autocommit + except self.dbapi.Error as err: + if self.is_disconnect(err, dbapi_connection, cursor): + return False + else: + raise + else: + return True + + def on_connect(self): + extras = self._psycopg2_extras() + extensions = self._psycopg2_extensions() + + fns = [] + if self.client_encoding is not None: + + def on_connect(conn): + conn.set_client_encoding(self.client_encoding) + + fns.append(on_connect) + + if self.isolation_level is not None: + + def on_connect(conn): + self.set_isolation_level(conn, self.isolation_level) + + fns.append(on_connect) + + if self.dbapi and self.use_native_uuid: + + def on_connect(conn): + extras.register_uuid(None, conn) + + fns.append(on_connect) + + if util.py2k and self.dbapi and self.use_native_unicode: + + def on_connect(conn): + extensions.register_type(extensions.UNICODE, conn) + extensions.register_type(extensions.UNICODEARRAY, conn) + + fns.append(on_connect) + + if self.dbapi and self.use_native_hstore: + + def on_connect(conn): + hstore_oids = self._hstore_oids(conn) + if hstore_oids is not None: + oid, array_oid = hstore_oids + kw = {"oid": oid} + if util.py2k: + kw["unicode"] = True + kw["array_oid"] = array_oid + extras.register_hstore(conn, **kw) + + fns.append(on_connect) + + if self.dbapi and self._json_deserializer: + + def on_connect(conn): + extras.register_default_json( + conn, loads=self._json_deserializer + ) + extras.register_default_jsonb( + conn, loads=self._json_deserializer + ) + + fns.append(on_connect) + + if fns: + + def on_connect(conn): + for fn in fns: + fn(conn) + + return on_connect + else: + return None + + def do_executemany(self, cursor, statement, parameters, context=None): + if ( + self.executemany_mode & EXECUTEMANY_VALUES + and context + and context.isinsert + and context.compiled._is_safe_for_fast_insert_values_helper + ): + executemany_values = ( + "(%s)" % context.compiled.insert_single_values_expr + ) + if not self.supports_unicode_statements: + executemany_values = executemany_values.encode(self.encoding) + + # guard for statement that was altered via event hook or similar + if executemany_values not in statement: + executemany_values = None + else: + executemany_values = None + + if executemany_values: + statement = statement.replace(executemany_values, "%s") + if self.executemany_values_page_size: + kwargs = {"page_size": self.executemany_values_page_size} + else: + kwargs = {} + xtras = self._psycopg2_extras() + context._psycopg2_fetched_rows = xtras.execute_values( + cursor, + statement, + parameters, + template=executemany_values, + fetch=bool(context.compiled.returning), + **kwargs + ) + + elif self.executemany_mode & EXECUTEMANY_BATCH: + if self.executemany_batch_page_size: + kwargs = {"page_size": self.executemany_batch_page_size} + else: + kwargs = {} + self._psycopg2_extras().execute_batch( + cursor, statement, parameters, **kwargs + ) + else: + cursor.executemany(statement, parameters) + + @util.memoized_instancemethod + def _hstore_oids(self, conn): + extras = self._psycopg2_extras() + if hasattr(conn, "dbapi_connection"): + conn = conn.dbapi_connection + oids = extras.HstoreAdapter.get_oids(conn) + if oids is not None and oids[0]: + return oids[0:2] + else: + return None + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + + is_multihost = False + if "host" in url.query: + is_multihost = isinstance(url.query["host"], (list, tuple)) + + if opts or url.query: + if not opts: + opts = {} + if "port" in opts: + opts["port"] = int(opts["port"]) + opts.update(url.query) + if is_multihost: + hosts, ports = zip( + *[ + token.split(":") if ":" in token else (token, "") + for token in url.query["host"] + ] + ) + opts["host"] = ",".join(hosts) + if "port" in opts: + raise exc.ArgumentError( + "Can't mix 'multihost' formats together; use " + '"host=h1,h2,h3&port=p1,p2,p3" or ' + '"host=h1:p1&host=h2:p2&host=h3:p3" separately' + ) + opts["port"] = ",".join(ports) + return ([], opts) + else: + # no connection arguments whatsoever; psycopg2.connect() + # requires that "dsn" be present as a blank string. + return ([""], opts) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error): + # check the "closed" flag. this might not be + # present on old psycopg2 versions. Also, + # this flag doesn't actually help in a lot of disconnect + # situations, so don't rely on it. + if getattr(connection, "closed", False): + return True + + # checks based on strings. in the case that .closed + # didn't cut it, fall back onto these. + str_e = str(e).partition("\n")[0] + for msg in [ + # these error messages from libpq: interfaces/libpq/fe-misc.c + # and interfaces/libpq/fe-secure.c. + "terminating connection", + "closed the connection", + "connection not open", + "could not receive data from server", + "could not send data to server", + # psycopg2 client errors, psycopg2/connection.h, + # psycopg2/cursor.h + "connection already closed", + "cursor already closed", + # not sure where this path is originally from, it may + # be obsolete. It really says "losed", not "closed". + "losed the connection unexpectedly", + # these can occur in newer SSL + "connection has been closed unexpectedly", + "SSL error: decryption failed or bad record mac", + "SSL SYSCALL error: Bad file descriptor", + "SSL SYSCALL error: EOF detected", + "SSL SYSCALL error: Operation timed out", + "SSL SYSCALL error: Bad address", + ]: + idx = str_e.find(msg) + if idx >= 0 and '"' not in str_e[:idx]: + return True + return False + + +dialect = PGDialect_psycopg2 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py new file mode 100644 index 0000000..10d1aae --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -0,0 +1,60 @@ +# testing/engines.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +r""" +.. dialect:: postgresql+psycopg2cffi + :name: psycopg2cffi + :dbapi: psycopg2cffi + :connectstring: postgresql+psycopg2cffi://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/psycopg2cffi/ + +``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C +layer. This makes it suitable for use in e.g. PyPy. Documentation +is as per ``psycopg2``. + +.. versionadded:: 1.0.0 + +.. seealso:: + + :mod:`sqlalchemy.dialects.postgresql.psycopg2` + +""" # noqa +from .psycopg2 import PGDialect_psycopg2 + + +class PGDialect_psycopg2cffi(PGDialect_psycopg2): + driver = "psycopg2cffi" + supports_unicode_statements = True + supports_statement_cache = True + + # psycopg2cffi's first release is 2.5.0, but reports + # __version__ as 2.4.4. Subsequent releases seem to have + # fixed this. + + FEATURE_VERSION_MAP = dict( + native_json=(2, 4, 4), + native_jsonb=(2, 7, 1), + sane_multi_rowcount=(2, 4, 4), + array_oid=(2, 4, 4), + hstore_adapter=(2, 4, 4), + ) + + @classmethod + def dbapi(cls): + return __import__("psycopg2cffi") + + @classmethod + def _psycopg2_extensions(cls): + root = __import__("psycopg2cffi", fromlist=["extensions"]) + return root.extensions + + @classmethod + def _psycopg2_extras(cls): + root = __import__("psycopg2cffi", fromlist=["extras"]) + return root.extras + + +dialect = PGDialect_psycopg2cffi diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py new file mode 100644 index 0000000..d273b8c --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py @@ -0,0 +1,278 @@ +# postgresql/pygresql.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +""" +.. dialect:: postgresql+pygresql + :name: pygresql + :dbapi: pgdb + :connectstring: postgresql+pygresql://user:password@host:port/dbname[?key=value&key=value...] + :url: https://www.pygresql.org/ + +.. note:: + + The pygresql dialect is **not tested as part of SQLAlchemy's continuous + integration** and may have unresolved issues. The recommended PostgreSQL + dialect is psycopg2. + +.. deprecated:: 1.4 The pygresql DBAPI is deprecated and will be removed + in a future version. Please use one of the supported DBAPIs to + connect to PostgreSQL. + +""" # noqa + +import decimal +import re + +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import PGCompiler +from .base import PGDialect +from .base import PGIdentifierPreparer +from .base import UUID +from .hstore import HSTORE +from .json import JSON +from .json import JSONB +from ... import exc +from ... import processors +from ... import util +from ...sql.elements import Null +from ...types import JSON as Json +from ...types import Numeric + + +class _PGNumeric(Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if not isinstance(coltype, int): + coltype = coltype.oid + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # PyGreSQL returns Decimal natively for 1700 (numeric) + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # PyGreSQL returns float natively for 701 (float8) + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class _PGHStore(HSTORE): + def bind_processor(self, dialect): + if not dialect.has_native_hstore: + return super(_PGHStore, self).bind_processor(dialect) + hstore = dialect.dbapi.Hstore + + def process(value): + if isinstance(value, dict): + return hstore(value) + return value + + return process + + def result_processor(self, dialect, coltype): + if not dialect.has_native_hstore: + return super(_PGHStore, self).result_processor(dialect, coltype) + + +class _PGJSON(JSON): + def bind_processor(self, dialect): + if not dialect.has_native_json: + return super(_PGJSON, self).bind_processor(dialect) + json = dialect.dbapi.Json + + def process(value): + if value is self.NULL: + value = None + elif isinstance(value, Null) or ( + value is None and self.none_as_null + ): + return None + if value is None or isinstance(value, (dict, list)): + return json(value) + return value + + return process + + def result_processor(self, dialect, coltype): + if not dialect.has_native_json: + return super(_PGJSON, self).result_processor(dialect, coltype) + + +class _PGJSONB(JSONB): + def bind_processor(self, dialect): + if not dialect.has_native_json: + return super(_PGJSONB, self).bind_processor(dialect) + json = dialect.dbapi.Json + + def process(value): + if value is self.NULL: + value = None + elif isinstance(value, Null) or ( + value is None and self.none_as_null + ): + return None + if value is None or isinstance(value, (dict, list)): + return json(value) + return value + + return process + + def result_processor(self, dialect, coltype): + if not dialect.has_native_json: + return super(_PGJSONB, self).result_processor(dialect, coltype) + + +class _PGUUID(UUID): + def bind_processor(self, dialect): + if not dialect.has_native_uuid: + return super(_PGUUID, self).bind_processor(dialect) + uuid = dialect.dbapi.Uuid + + def process(value): + if value is None: + return None + if isinstance(value, (str, bytes)): + if len(value) == 16: + return uuid(bytes=value) + return uuid(value) + if isinstance(value, int): + return uuid(int=value) + return value + + return process + + def result_processor(self, dialect, coltype): + if not dialect.has_native_uuid: + return super(_PGUUID, self).result_processor(dialect, coltype) + if not self.as_uuid: + + def process(value): + if value is not None: + return str(value) + + return process + + +class _PGCompiler(PGCompiler): + def visit_mod_binary(self, binary, operator, **kw): + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) + + def post_process_text(self, text): + return text.replace("%", "%%") + + +class _PGIdentifierPreparer(PGIdentifierPreparer): + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value.replace("%", "%%") + + +class PGDialect_pygresql(PGDialect): + + driver = "pygresql" + supports_statement_cache = True + + statement_compiler = _PGCompiler + preparer = _PGIdentifierPreparer + + @classmethod + def dbapi(cls): + import pgdb + + util.warn_deprecated( + "The pygresql DBAPI is deprecated and will be removed " + "in a future version. Please use one of the supported DBAPIs to " + "connect to PostgreSQL.", + version="1.4", + ) + + return pgdb + + colspecs = util.update_copy( + PGDialect.colspecs, + { + Numeric: _PGNumeric, + HSTORE: _PGHStore, + Json: _PGJSON, + JSON: _PGJSON, + JSONB: _PGJSONB, + UUID: _PGUUID, + }, + ) + + def __init__(self, **kwargs): + super(PGDialect_pygresql, self).__init__(**kwargs) + try: + version = self.dbapi.version + m = re.match(r"(\d+)\.(\d+)", version) + version = (int(m.group(1)), int(m.group(2))) + except (AttributeError, ValueError, TypeError): + version = (0, 0) + self.dbapi_version = version + if version < (5, 0): + has_native_hstore = has_native_json = has_native_uuid = False + if version != (0, 0): + util.warn( + "PyGreSQL is only fully supported by SQLAlchemy" + " since version 5.0." + ) + else: + self.supports_unicode_statements = True + self.supports_unicode_binds = True + has_native_hstore = has_native_json = has_native_uuid = True + self.has_native_hstore = has_native_hstore + self.has_native_json = has_native_json + self.has_native_uuid = has_native_uuid + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["host"] = "%s:%s" % ( + opts.get("host", "").rsplit(":", 1)[0], + opts.pop("port"), + ) + opts.update(url.query) + return [], opts + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error): + if not connection: + return False + try: + connection = connection.connection + except AttributeError: + pass + else: + if not connection: + return False + try: + return connection.closed + except AttributeError: # PyGreSQL < 5.0 + return connection._cnx is None + return False + + +dialect = PGDialect_pygresql diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py new file mode 100644 index 0000000..886e368 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -0,0 +1,126 @@ +# postgresql/pypostgresql.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +""" +.. dialect:: postgresql+pypostgresql + :name: py-postgresql + :dbapi: pypostgresql + :connectstring: postgresql+pypostgresql://user:password@host:port/dbname[?key=value&key=value...] + :url: https://python.projects.pgfoundry.org/ + +.. note:: + + The pypostgresql dialect is **not tested as part of SQLAlchemy's continuous + integration** and may have unresolved issues. The recommended PostgreSQL + driver is psycopg2. + +.. deprecated:: 1.4 The py-postgresql DBAPI is deprecated and will be removed + in a future version. This DBAPI is superseded by the external + version available at external-dialect_. Please use the external version or + one of the supported DBAPIs to connect to PostgreSQL. + +.. TODO update link +.. _external-dialect: https://github.com/PyGreSQL + +""" # noqa + +from .base import PGDialect +from .base import PGExecutionContext +from ... import processors +from ... import types as sqltypes +from ... import util + + +class PGNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return processors.to_str + + def result_processor(self, dialect, coltype): + if self.asdecimal: + return None + else: + return processors.to_float + + +class PGExecutionContext_pypostgresql(PGExecutionContext): + pass + + +class PGDialect_pypostgresql(PGDialect): + driver = "pypostgresql" + + supports_statement_cache = True + supports_unicode_statements = True + supports_unicode_binds = True + description_encoding = None + default_paramstyle = "pyformat" + + # requires trunk version to support sane rowcounts + # TODO: use dbapi version information to set this flag appropriately + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + + execution_ctx_cls = PGExecutionContext_pypostgresql + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric: PGNumeric, + # prevents PGNumeric from being used + sqltypes.Float: sqltypes.Float, + }, + ) + + @classmethod + def dbapi(cls): + from postgresql.driver import dbapi20 + + # TODO update link + util.warn_deprecated( + "The py-postgresql DBAPI is deprecated and will be removed " + "in a future version. This DBAPI is superseded by the external" + "version available at https://github.com/PyGreSQL. Please " + "use one of the supported DBAPIs to connect to PostgreSQL.", + version="1.4", + ) + + return dbapi20 + + _DBAPI_ERROR_NAMES = [ + "Error", + "InterfaceError", + "DatabaseError", + "DataError", + "OperationalError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "NotSupportedError", + ] + + @util.memoized_property + def dbapi_exception_translation_map(self): + if self.dbapi is None: + return {} + + return dict( + (getattr(self.dbapi, name).__name__, name) + for name in self._DBAPI_ERROR_NAMES + ) + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) + else: + opts["port"] = 5432 + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e, connection, cursor): + return "connection is closed" in str(e) + + +dialect = PGDialect_pypostgresql diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py new file mode 100644 index 0000000..51f3b04 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -0,0 +1,138 @@ +# Copyright (C) 2013-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from ... import types as sqltypes + + +__all__ = ("INT4RANGE", "INT8RANGE", "NUMRANGE") + + +class RangeOperators(object): + """ + This mixin provides functionality for the Range Operators + listed in the Range Operators table of the `PostgreSQL documentation`__ + for Range Functions and Operators. It is used by all the range types + provided in the ``postgres`` dialect and can likely be used for + any range types you create yourself. + + __ https://www.postgresql.org/docs/current/static/functions-range.html + + No extra support is provided for the Range Functions listed in the Range + Functions table of the PostgreSQL documentation. For these, the normal + :func:`~sqlalchemy.sql.expression.func` object should be used. + + """ + + class comparator_factory(sqltypes.Concatenable.Comparator): + """Define comparison operations for range types.""" + + def __ne__(self, other): + "Boolean expression. Returns true if two ranges are not equal" + if other is None: + return super(RangeOperators.comparator_factory, self).__ne__( + other + ) + else: + return self.expr.op("<>", is_comparison=True)(other) + + def contains(self, other, **kw): + """Boolean expression. Returns true if the right hand operand, + which can be an element or a range, is contained within the + column. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.expr.op("@>", is_comparison=True)(other) + + def contained_by(self, other): + """Boolean expression. Returns true if the column is contained + within the right hand operand. + """ + return self.expr.op("<@", is_comparison=True)(other) + + def overlaps(self, other): + """Boolean expression. Returns true if the column overlaps + (has points in common with) the right hand operand. + """ + return self.expr.op("&&", is_comparison=True)(other) + + def strictly_left_of(self, other): + """Boolean expression. Returns true if the column is strictly + left of the right hand operand. + """ + return self.expr.op("<<", is_comparison=True)(other) + + __lshift__ = strictly_left_of + + def strictly_right_of(self, other): + """Boolean expression. Returns true if the column is strictly + right of the right hand operand. + """ + return self.expr.op(">>", is_comparison=True)(other) + + __rshift__ = strictly_right_of + + def not_extend_right_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend right of the range in the operand. + """ + return self.expr.op("&<", is_comparison=True)(other) + + def not_extend_left_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend left of the range in the operand. + """ + return self.expr.op("&>", is_comparison=True)(other) + + def adjacent_to(self, other): + """Boolean expression. Returns true if the range in the column + is adjacent to the range in the operand. + """ + return self.expr.op("-|-", is_comparison=True)(other) + + def __add__(self, other): + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + return self.expr.op("+")(other) + + +class INT4RANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the PostgreSQL INT4RANGE type.""" + + __visit_name__ = "INT4RANGE" + + +class INT8RANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the PostgreSQL INT8RANGE type.""" + + __visit_name__ = "INT8RANGE" + + +class NUMRANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the PostgreSQL NUMRANGE type.""" + + __visit_name__ = "NUMRANGE" + + +class DATERANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the PostgreSQL DATERANGE type.""" + + __visit_name__ = "DATERANGE" + + +class TSRANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the PostgreSQL TSRANGE type.""" + + __visit_name__ = "TSRANGE" + + +class TSTZRANGE(RangeOperators, sqltypes.TypeEngine): + """Represent the PostgreSQL TSTZRANGE type.""" + + __visit_name__ = "TSTZRANGE" diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py new file mode 100644 index 0000000..8d8d933 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/__init__.py @@ -0,0 +1,58 @@ +# sqlite/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from . import base # noqa +from . import pysqlcipher # noqa +from . import pysqlite # noqa +from .base import BLOB +from .base import BOOLEAN +from .base import CHAR +from .base import DATE +from .base import DATETIME +from .base import DECIMAL +from .base import FLOAT +from .base import INTEGER +from .base import JSON +from .base import NUMERIC +from .base import REAL +from .base import SMALLINT +from .base import TEXT +from .base import TIME +from .base import TIMESTAMP +from .base import VARCHAR +from .dml import Insert +from .dml import insert +from ...util import compat + +if compat.py3k: + from . import aiosqlite # noqa + +# default dialect +base.dialect = dialect = pysqlite.dialect + + +__all__ = ( + "BLOB", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "FLOAT", + "INTEGER", + "JSON", + "NUMERIC", + "SMALLINT", + "TEXT", + "TIME", + "TIMESTAMP", + "VARCHAR", + "REAL", + "Insert", + "insert", + "dialect", +) diff --git a/lib/sqlalchemy/dialects/sqlite/aiosqlite.py b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py new file mode 100644 index 0000000..9fc6d35 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -0,0 +1,335 @@ +# sqlite/aiosqlite.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" + +.. dialect:: sqlite+aiosqlite + :name: aiosqlite + :dbapi: aiosqlite + :connectstring: sqlite+aiosqlite:///file_path + :url: https://pypi.org/project/aiosqlite/ + +The aiosqlite dialect provides support for the SQLAlchemy asyncio interface +running on top of pysqlite. + +aiosqlite is a wrapper around pysqlite that uses a background thread for +each connection. It does not actually use non-blocking IO, as SQLite +databases are not socket-based. However it does provide a working asyncio +interface that's useful for testing and prototyping purposes. + +Using a special asyncio mediation layer, the aiosqlite dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("sqlite+aiosqlite:///filename") + +The URL passes through all arguments to the ``pysqlite`` driver, so all +connection arguments are the same as they are for that of :ref:`pysqlite`. + + +""" # noqa + +from .base import SQLiteExecutionContext +from .pysqlite import SQLiteDialect_pysqlite +from ... import pool +from ... import util +from ...engine import AdaptedConnection +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_aiosqlite_cursor: + __slots__ = ( + "_adapt_connection", + "_connection", + "description", + "await_", + "_rows", + "arraysize", + "rowcount", + "lastrowid", + ) + + server_side = False + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + self.arraysize = 1 + self.rowcount = -1 + self.description = None + self._rows = [] + + def close(self): + self._rows[:] = [] + + def execute(self, operation, parameters=None): + try: + _cursor = self.await_(self._connection.cursor()) + + if parameters is None: + self.await_(_cursor.execute(operation)) + else: + self.await_(_cursor.execute(operation, parameters)) + + if _cursor.description: + self.description = _cursor.description + self.lastrowid = self.rowcount = -1 + + if not self.server_side: + self._rows = self.await_(_cursor.fetchall()) + else: + self.description = None + self.lastrowid = _cursor.lastrowid + self.rowcount = _cursor.rowcount + + if not self.server_side: + self.await_(_cursor.close()) + else: + self._cursor = _cursor + except Exception as error: + self._adapt_connection._handle_exception(error) + + def executemany(self, operation, seq_of_parameters): + try: + _cursor = self.await_(self._connection.cursor()) + self.await_(_cursor.executemany(operation, seq_of_parameters)) + self.description = None + self.lastrowid = _cursor.lastrowid + self.rowcount = _cursor.rowcount + self.await_(_cursor.close()) + except Exception as error: + self._adapt_connection._handle_exception(error) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor): + __slots__ = "_cursor" + + server_side = True + + def __init__(self, *arg, **kw): + super().__init__(*arg, **kw) + self._cursor = None + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_aiosqlite_connection(AdaptedConnection): + await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_connection") + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + + @property + def isolation_level(self): + return self._connection.isolation_level + + @isolation_level.setter + def isolation_level(self, value): + try: + self._connection.isolation_level = value + except Exception as error: + self._handle_exception(error) + + def create_function(self, *args, **kw): + try: + self.await_(self._connection.create_function(*args, **kw)) + except Exception as error: + self._handle_exception(error) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_aiosqlite_ss_cursor(self) + else: + return AsyncAdapt_aiosqlite_cursor(self) + + def execute(self, *args, **kw): + return self.await_(self._connection.execute(*args, **kw)) + + def rollback(self): + try: + self.await_(self._connection.rollback()) + except Exception as error: + self._handle_exception(error) + + def commit(self): + try: + self.await_(self._connection.commit()) + except Exception as error: + self._handle_exception(error) + + def close(self): + try: + self.await_(self._connection.close()) + except Exception as error: + self._handle_exception(error) + + def _handle_exception(self, error): + if ( + isinstance(error, ValueError) + and error.args[0] == "no active connection" + ): + util.raise_( + self.dbapi.sqlite.OperationalError("no active connection"), + from_=error, + ) + else: + raise error + + +class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_aiosqlite_dbapi: + def __init__(self, aiosqlite, sqlite): + self.aiosqlite = aiosqlite + self.sqlite = sqlite + self.paramstyle = "qmark" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name in ( + "DatabaseError", + "Error", + "IntegrityError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "sqlite_version", + "sqlite_version_info", + ): + setattr(self, name, getattr(self.aiosqlite, name)) + + for name in ("PARSE_COLNAMES", "PARSE_DECLTYPES"): + setattr(self, name, getattr(self.sqlite, name)) + + for name in ("Binary",): + setattr(self, name, getattr(self.sqlite, name)) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + + # Q. WHY do we need this? + # A. Because there is no way to set connection.isolation_level + # otherwise + # Q. BUT HOW do you know it is SAFE ????? + # A. The only operation that isn't safe is the isolation level set + # operation which aiosqlite appears to have let slip through even + # though pysqlite appears to do check_same_thread for this. + # All execute operations etc. should be safe because they all + # go through the single executor thread. + + kw["check_same_thread"] = False + + connection = self.aiosqlite.connect(*arg, **kw) + + # it's a Thread. you'll thank us later + connection.daemon = True + + if util.asbool(async_fallback): + return AsyncAdaptFallback_aiosqlite_connection( + self, + await_fallback(connection), + ) + else: + return AsyncAdapt_aiosqlite_connection( + self, + await_only(connection), + ) + + +class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext): + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(server_side=True) + + +class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): + driver = "aiosqlite" + supports_statement_cache = True + + is_async = True + + supports_server_side_cursors = True + + execution_ctx_cls = SQLiteExecutionContext_aiosqlite + + @classmethod + def dbapi(cls): + return AsyncAdapt_aiosqlite_dbapi( + __import__("aiosqlite"), __import__("sqlite3") + ) + + @classmethod + def get_pool_class(cls, url): + if cls._is_url_file_db(url): + return pool.NullPool + else: + return pool.StaticPool + + def is_disconnect(self, e, connection, cursor): + if isinstance( + e, self.dbapi.OperationalError + ) and "no active connection" in str(e): + return True + + return super().is_disconnect(e, connection, cursor) + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = SQLiteDialect_aiosqlite diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py new file mode 100644 index 0000000..0959d04 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -0,0 +1,2556 @@ +# sqlite/base.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" +.. dialect:: sqlite + :name: SQLite + :full_support: 3.21, 3.28+ + :normal_support: 3.12+ + :best_effort: 3.7.16+ + +.. _sqlite_datetime: + +Date and Time Types +------------------- + +SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does +not provide out of the box functionality for translating values between Python +`datetime` objects and a SQLite-supported format. SQLAlchemy's own +:class:`~sqlalchemy.types.DateTime` and related types provide date formatting +and parsing functionality when SQLite is used. The implementation classes are +:class:`_sqlite.DATETIME`, :class:`_sqlite.DATE` and :class:`_sqlite.TIME`. +These types represent dates and times as ISO formatted strings, which also +nicely support ordering. There's no reliance on typical "libc" internals for +these functions so historical dates are fully supported. + +Ensuring Text affinity +^^^^^^^^^^^^^^^^^^^^^^ + +The DDL rendered for these types is the standard ``DATE``, ``TIME`` +and ``DATETIME`` indicators. However, custom storage formats can also be +applied to these types. When the +storage format is detected as containing no alpha characters, the DDL for +these types is rendered as ``DATE_CHAR``, ``TIME_CHAR``, and ``DATETIME_CHAR``, +so that the column continues to have textual affinity. + +.. seealso:: + + `Type Affinity <https://www.sqlite.org/datatype3.html#affinity>`_ - + in the SQLite documentation + +.. _sqlite_autoincrement: + +SQLite Auto Incrementing Behavior +---------------------------------- + +Background on SQLite's autoincrement is at: https://sqlite.org/autoinc.html + +Key concepts: + +* SQLite has an implicit "auto increment" feature that takes place for any + non-composite primary-key column that is specifically created using + "INTEGER PRIMARY KEY" for the type + primary key. + +* SQLite also has an explicit "AUTOINCREMENT" keyword, that is **not** + equivalent to the implicit autoincrement feature; this keyword is not + recommended for general use. SQLAlchemy does not render this keyword + unless a special SQLite-specific directive is used (see below). However, + it still requires that the column's type is named "INTEGER". + +Using the AUTOINCREMENT Keyword +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To specifically render the AUTOINCREMENT keyword on the primary key column +when rendering DDL, add the flag ``sqlite_autoincrement=True`` to the Table +construct:: + + Table('sometable', metadata, + Column('id', Integer, primary_key=True), + sqlite_autoincrement=True) + +Allowing autoincrement behavior SQLAlchemy types other than Integer/INTEGER +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +SQLite's typing model is based on naming conventions. Among other things, this +means that any type name which contains the substring ``"INT"`` will be +determined to be of "integer affinity". A type named ``"BIGINT"``, +``"SPECIAL_INT"`` or even ``"XYZINTQPR"``, will be considered by SQLite to be +of "integer" affinity. However, **the SQLite autoincrement feature, whether +implicitly or explicitly enabled, requires that the name of the column's type +is exactly the string "INTEGER"**. Therefore, if an application uses a type +like :class:`.BigInteger` for a primary key, on SQLite this type will need to +be rendered as the name ``"INTEGER"`` when emitting the initial ``CREATE +TABLE`` statement in order for the autoincrement behavior to be available. + +One approach to achieve this is to use :class:`.Integer` on SQLite +only using :meth:`.TypeEngine.with_variant`:: + + table = Table( + "my_table", metadata, + Column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True) + ) + +Another is to use a subclass of :class:`.BigInteger` that overrides its DDL +name to be ``INTEGER`` when compiled against SQLite:: + + from sqlalchemy import BigInteger + from sqlalchemy.ext.compiler import compiles + + class SLBigInteger(BigInteger): + pass + + @compiles(SLBigInteger, 'sqlite') + def bi_c(element, compiler, **kw): + return "INTEGER" + + @compiles(SLBigInteger) + def bi_c(element, compiler, **kw): + return compiler.visit_BIGINT(element, **kw) + + + table = Table( + "my_table", metadata, + Column("id", SLBigInteger(), primary_key=True) + ) + +.. seealso:: + + :meth:`.TypeEngine.with_variant` + + :ref:`sqlalchemy.ext.compiler_toplevel` + + `Datatypes In SQLite Version 3 <https://sqlite.org/datatype3.html>`_ + +.. _sqlite_concurrency: + +Database Locking Behavior / Concurrency +--------------------------------------- + +SQLite is not designed for a high level of write concurrency. The database +itself, being a file, is locked completely during write operations within +transactions, meaning exactly one "connection" (in reality a file handle) +has exclusive access to the database during this period - all other +"connections" will be blocked during this time. + +The Python DBAPI specification also calls for a connection model that is +always in a transaction; there is no ``connection.begin()`` method, +only ``connection.commit()`` and ``connection.rollback()``, upon which a +new transaction is to be begun immediately. This may seem to imply +that the SQLite driver would in theory allow only a single filehandle on a +particular database file at any time; however, there are several +factors both within SQLite itself as well as within the pysqlite driver +which loosen this restriction significantly. + +However, no matter what locking modes are used, SQLite will still always +lock the database file once a transaction is started and DML (e.g. INSERT, +UPDATE, DELETE) has at least been emitted, and this will block +other transactions at least at the point that they also attempt to emit DML. +By default, the length of time on this block is very short before it times out +with an error. + +This behavior becomes more critical when used in conjunction with the +SQLAlchemy ORM. SQLAlchemy's :class:`.Session` object by default runs +within a transaction, and with its autoflush model, may emit DML preceding +any SELECT statement. This may lead to a SQLite database that locks +more quickly than is expected. The locking mode of SQLite and the pysqlite +driver can be manipulated to some degree, however it should be noted that +achieving a high degree of write-concurrency with SQLite is a losing battle. + +For more information on SQLite's lack of write concurrency by design, please +see +`Situations Where Another RDBMS May Work Better - High Concurrency +<https://www.sqlite.org/whentouse.html>`_ near the bottom of the page. + +The following subsections introduce areas that are impacted by SQLite's +file-based architecture and additionally will usually require workarounds to +work when using the pysqlite driver. + +.. _sqlite_isolation_level: + +Transaction Isolation Level / Autocommit +---------------------------------------- + +SQLite supports "transaction isolation" in a non-standard way, along two +axes. One is that of the +`PRAGMA read_uncommitted <https://www.sqlite.org/pragma.html#pragma_read_uncommitted>`_ +instruction. This setting can essentially switch SQLite between its +default mode of ``SERIALIZABLE`` isolation, and a "dirty read" isolation +mode normally referred to as ``READ UNCOMMITTED``. + +SQLAlchemy ties into this PRAGMA statement using the +:paramref:`_sa.create_engine.isolation_level` parameter of +:func:`_sa.create_engine`. +Valid values for this parameter when used with SQLite are ``"SERIALIZABLE"`` +and ``"READ UNCOMMITTED"`` corresponding to a value of 0 and 1, respectively. +SQLite defaults to ``SERIALIZABLE``, however its behavior is impacted by +the pysqlite driver's default behavior. + +When using the pysqlite driver, the ``"AUTOCOMMIT"`` isolation level is also +available, which will alter the pysqlite connection using the ``.isolation_level`` +attribute on the DBAPI connection and set it to None for the duration +of the setting. + +.. versionadded:: 1.3.16 added support for SQLite AUTOCOMMIT isolation level + when using the pysqlite / sqlite3 SQLite driver. + + +The other axis along which SQLite's transactional locking is impacted is +via the nature of the ``BEGIN`` statement used. The three varieties +are "deferred", "immediate", and "exclusive", as described at +`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_. A straight +``BEGIN`` statement uses the "deferred" mode, where the database file is +not locked until the first read or write operation, and read access remains +open to other transactions until the first write operation. But again, +it is critical to note that the pysqlite driver interferes with this behavior +by *not even emitting BEGIN* until the first write operation. + +.. warning:: + + SQLite's transactional scope is impacted by unresolved + issues in the pysqlite driver, which defers BEGIN statements to a greater + degree than is often feasible. See the section :ref:`pysqlite_serializable` + for techniques to work around this behavior. + +.. seealso:: + + :ref:`dbapi_autocommit` + +SAVEPOINT Support +---------------------------- + +SQLite supports SAVEPOINTs, which only function once a transaction is +begun. SQLAlchemy's SAVEPOINT support is available using the +:meth:`_engine.Connection.begin_nested` method at the Core level, and +:meth:`.Session.begin_nested` at the ORM level. However, SAVEPOINTs +won't work at all with pysqlite unless workarounds are taken. + +.. warning:: + + SQLite's SAVEPOINT feature is impacted by unresolved + issues in the pysqlite driver, which defers BEGIN statements to a greater + degree than is often feasible. See the section :ref:`pysqlite_serializable` + for techniques to work around this behavior. + +Transactional DDL +---------------------------- + +The SQLite database supports transactional :term:`DDL` as well. +In this case, the pysqlite driver is not only failing to start transactions, +it also is ending any existing transaction when DDL is detected, so again, +workarounds are required. + +.. warning:: + + SQLite's transactional DDL is impacted by unresolved issues + in the pysqlite driver, which fails to emit BEGIN and additionally + forces a COMMIT to cancel any transaction when DDL is encountered. + See the section :ref:`pysqlite_serializable` + for techniques to work around this behavior. + +.. _sqlite_foreign_keys: + +Foreign Key Support +------------------- + +SQLite supports FOREIGN KEY syntax when emitting CREATE statements for tables, +however by default these constraints have no effect on the operation of the +table. + +Constraint checking on SQLite has three prerequisites: + +* At least version 3.6.19 of SQLite must be in use +* The SQLite library must be compiled *without* the SQLITE_OMIT_FOREIGN_KEY + or SQLITE_OMIT_TRIGGER symbols enabled. +* The ``PRAGMA foreign_keys = ON`` statement must be emitted on all + connections before use -- including the initial call to + :meth:`sqlalchemy.schema.MetaData.create_all`. + +SQLAlchemy allows for the ``PRAGMA`` statement to be emitted automatically for +new connections through the usage of events:: + + from sqlalchemy.engine import Engine + from sqlalchemy import event + + @event.listens_for(Engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + +.. warning:: + + When SQLite foreign keys are enabled, it is **not possible** + to emit CREATE or DROP statements for tables that contain + mutually-dependent foreign key constraints; + to emit the DDL for these tables requires that ALTER TABLE be used to + create or drop these constraints separately, for which SQLite has + no support. + +.. seealso:: + + `SQLite Foreign Key Support <https://www.sqlite.org/foreignkeys.html>`_ + - on the SQLite web site. + + :ref:`event_toplevel` - SQLAlchemy event API. + + :ref:`use_alter` - more information on SQLAlchemy's facilities for handling + mutually-dependent foreign key constraints. + +.. _sqlite_on_conflict_ddl: + +ON CONFLICT support for constraints +----------------------------------- + +.. seealso:: This section describes the :term:`DDL` version of "ON CONFLICT" for + SQLite, which occurs within a CREATE TABLE statement. For "ON CONFLICT" as + applied to an INSERT statement, see :ref:`sqlite_on_conflict_insert`. + +SQLite supports a non-standard DDL clause known as ON CONFLICT which can be applied +to primary key, unique, check, and not null constraints. In DDL, it is +rendered either within the "CONSTRAINT" clause or within the column definition +itself depending on the location of the target constraint. To render this +clause within DDL, the extension parameter ``sqlite_on_conflict`` can be +specified with a string conflict resolution algorithm within the +:class:`.PrimaryKeyConstraint`, :class:`.UniqueConstraint`, +:class:`.CheckConstraint` objects. Within the :class:`_schema.Column` object, +there +are individual parameters ``sqlite_on_conflict_not_null``, +``sqlite_on_conflict_primary_key``, ``sqlite_on_conflict_unique`` which each +correspond to the three types of relevant constraint types that can be +indicated from a :class:`_schema.Column` object. + +.. seealso:: + + `ON CONFLICT <https://www.sqlite.org/lang_conflict.html>`_ - in the SQLite + documentation + +.. versionadded:: 1.3 + + +The ``sqlite_on_conflict`` parameters accept a string argument which is just +the resolution name to be chosen, which on SQLite can be one of ROLLBACK, +ABORT, FAIL, IGNORE, and REPLACE. For example, to add a UNIQUE constraint +that specifies the IGNORE algorithm:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer), + UniqueConstraint('id', 'data', sqlite_on_conflict='IGNORE') + ) + +The above renders CREATE TABLE DDL as:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + data INTEGER, + PRIMARY KEY (id), + UNIQUE (id, data) ON CONFLICT IGNORE + ) + + +When using the :paramref:`_schema.Column.unique` +flag to add a UNIQUE constraint +to a single column, the ``sqlite_on_conflict_unique`` parameter can +be added to the :class:`_schema.Column` as well, which will be added to the +UNIQUE constraint in the DDL:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer, unique=True, + sqlite_on_conflict_unique='IGNORE') + ) + +rendering:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + data INTEGER, + PRIMARY KEY (id), + UNIQUE (data) ON CONFLICT IGNORE + ) + +To apply the FAIL algorithm for a NOT NULL constraint, +``sqlite_on_conflict_not_null`` is used:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer, nullable=False, + sqlite_on_conflict_not_null='FAIL') + ) + +this renders the column inline ON CONFLICT phrase:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + data INTEGER NOT NULL ON CONFLICT FAIL, + PRIMARY KEY (id) + ) + + +Similarly, for an inline primary key, use ``sqlite_on_conflict_primary_key``:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True, + sqlite_on_conflict_primary_key='FAIL') + ) + +SQLAlchemy renders the PRIMARY KEY constraint separately, so the conflict +resolution algorithm is applied to the constraint itself:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + PRIMARY KEY (id) ON CONFLICT FAIL + ) + +.. _sqlite_on_conflict_insert: + +INSERT...ON CONFLICT (Upsert) +----------------------------------- + +.. seealso:: This section describes the :term:`DML` version of "ON CONFLICT" for + SQLite, which occurs within an INSERT statement. For "ON CONFLICT" as + applied to a CREATE TABLE statement, see :ref:`sqlite_on_conflict_ddl`. + +From version 3.24.0 onwards, SQLite supports "upserts" (update or insert) +of rows into a table via the ``ON CONFLICT`` clause of the ``INSERT`` +statement. A candidate row will only be inserted if that row does not violate +any unique or primary key constraints. In the case of a unique constraint violation, a +secondary action can occur which can be either "DO UPDATE", indicating that +the data in the target row should be updated, or "DO NOTHING", which indicates +to silently skip this row. + +Conflicts are determined using columns that are part of existing unique +constraints and indexes. These constraints are identified by stating the +columns and conditions that comprise the indexes. + +SQLAlchemy provides ``ON CONFLICT`` support via the SQLite-specific +:func:`_sqlite.insert()` function, which provides +the generative methods :meth:`_sqlite.Insert.on_conflict_do_update` +and :meth:`_sqlite.Insert.on_conflict_do_nothing`: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.dialects.sqlite import insert + + >>> insert_stmt = insert(my_table).values( + ... id='some_existing_id', + ... data='inserted value') + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... ) + + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (?, ?) + ON CONFLICT (id) DO UPDATE SET data = ?{stop} + + >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing( + ... index_elements=['id'] + ... ) + + >>> print(do_nothing_stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (?, ?) + ON CONFLICT (id) DO NOTHING + +.. versionadded:: 1.4 + +.. seealso:: + + `Upsert + <https://sqlite.org/lang_UPSERT.html>`_ + - in the SQLite documentation. + + +Specifying the Target +^^^^^^^^^^^^^^^^^^^^^ + +Both methods supply the "target" of the conflict using column inference: + +* The :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements` argument + specifies a sequence containing string column names, :class:`_schema.Column` + objects, and/or SQL expression elements, which would identify a unique index + or unique constraint. + +* When using :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements` + to infer an index, a partial index can be inferred by also specifying the + :paramref:`_sqlite.Insert.on_conflict_do_update.index_where` parameter: + + .. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data') + + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=[my_table.c.user_email], + ... index_where=my_table.c.user_email.like('%@gmail.com'), + ... set_=dict(data=stmt.excluded.data) + ... ) + + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (data, user_email) VALUES (?, ?) + ON CONFLICT (user_email) + WHERE user_email LIKE '%@gmail.com' + DO UPDATE SET data = excluded.data + >>> + +The SET Clause +^^^^^^^^^^^^^^^ + +``ON CONFLICT...DO UPDATE`` is used to perform an update of the already +existing row, using any combination of new values as well as values +from the proposed insertion. These values are specified using the +:paramref:`_sqlite.Insert.on_conflict_do_update.set_` parameter. This +parameter accepts a dictionary which consists of direct values +for UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... ) + + >>> print(do_update_stmt) + + {opensql}INSERT INTO my_table (id, data) VALUES (?, ?) + ON CONFLICT (id) DO UPDATE SET data = ? + +.. warning:: + + The :meth:`_sqlite.Insert.on_conflict_do_update` method does **not** take + into account Python-side default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. These + values will not be exercised for an ON CONFLICT style of UPDATE, unless + they are manually specified in the + :paramref:`_sqlite.Insert.on_conflict_do_update.set_` dictionary. + +Updating using the Excluded INSERT Values +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to refer to the proposed insertion row, the special alias +:attr:`~.sqlite.Insert.excluded` is available as an attribute on +the :class:`_sqlite.Insert` object; this object creates an "excluded." prefix +on a column, that informs the DO UPDATE to update the row with the value that +would have been inserted had the constraint not failed: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author) + ... ) + + >>> print(do_update_stmt) + {opensql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?) + ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author + +Additional WHERE Criteria +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :meth:`_sqlite.Insert.on_conflict_do_update` method also accepts +a WHERE clause using the :paramref:`_sqlite.Insert.on_conflict_do_update.where` +parameter, which will limit those rows which receive an UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + + >>> on_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author), + ... where=(my_table.c.status == 2) + ... ) + >>> print(on_update_stmt) + {opensql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?) + ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author + WHERE my_table.status = ? + + +Skipping Rows with DO NOTHING +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``ON CONFLICT`` may be used to skip inserting a row entirely +if any conflict with a unique constraint occurs; below this is illustrated +using the :meth:`_sqlite.Insert.on_conflict_do_nothing` method: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id']) + >>> print(stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT (id) DO NOTHING + + +If ``DO NOTHING`` is used without specifying any columns or constraint, +it has the effect of skipping the INSERT for any unique violation which +occurs: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing() + >>> print(stmt) + {opensql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT DO NOTHING + +.. _sqlite_type_reflection: + +Type Reflection +--------------- + +SQLite types are unlike those of most other database backends, in that +the string name of the type usually does not correspond to a "type" in a +one-to-one fashion. Instead, SQLite links per-column typing behavior +to one of five so-called "type affinities" based on a string matching +pattern for the type. + +SQLAlchemy's reflection process, when inspecting types, uses a simple +lookup table to link the keywords returned to provided SQLAlchemy types. +This lookup table is present within the SQLite dialect as it is for all +other dialects. However, the SQLite dialect has a different "fallback" +routine for when a particular type name is not located in the lookup map; +it instead implements the SQLite "type affinity" scheme located at +https://www.sqlite.org/datatype3.html section 2.1. + +The provided typemap will make direct associations from an exact string +name match for the following types: + +:class:`_types.BIGINT`, :class:`_types.BLOB`, +:class:`_types.BOOLEAN`, :class:`_types.BOOLEAN`, +:class:`_types.CHAR`, :class:`_types.DATE`, +:class:`_types.DATETIME`, :class:`_types.FLOAT`, +:class:`_types.DECIMAL`, :class:`_types.FLOAT`, +:class:`_types.INTEGER`, :class:`_types.INTEGER`, +:class:`_types.NUMERIC`, :class:`_types.REAL`, +:class:`_types.SMALLINT`, :class:`_types.TEXT`, +:class:`_types.TIME`, :class:`_types.TIMESTAMP`, +:class:`_types.VARCHAR`, :class:`_types.NVARCHAR`, +:class:`_types.NCHAR` + +When a type name does not match one of the above types, the "type affinity" +lookup is used instead: + +* :class:`_types.INTEGER` is returned if the type name includes the + string ``INT`` +* :class:`_types.TEXT` is returned if the type name includes the + string ``CHAR``, ``CLOB`` or ``TEXT`` +* :class:`_types.NullType` is returned if the type name includes the + string ``BLOB`` +* :class:`_types.REAL` is returned if the type name includes the string + ``REAL``, ``FLOA`` or ``DOUB``. +* Otherwise, the :class:`_types.NUMERIC` type is used. + +.. versionadded:: 0.9.3 Support for SQLite type affinity rules when reflecting + columns. + + +.. _sqlite_partial_index: + +Partial Indexes +--------------- + +A partial index, e.g. one which uses a WHERE clause, can be specified +with the DDL system using the argument ``sqlite_where``:: + + tbl = Table('testtbl', m, Column('data', Integer)) + idx = Index('test_idx1', tbl.c.data, + sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + +The index will be rendered at create time as:: + + CREATE INDEX test_idx1 ON testtbl (data) + WHERE data > 5 AND data < 10 + +.. versionadded:: 0.9.9 + +.. _sqlite_dotted_column_names: + +Dotted Column Names +------------------- + +Using table or column names that explicitly have periods in them is +**not recommended**. While this is generally a bad idea for relational +databases in general, as the dot is a syntactically significant character, +the SQLite driver up until version **3.10.0** of SQLite has a bug which +requires that SQLAlchemy filter out these dots in result sets. + +.. versionchanged:: 1.1 + + The following SQLite issue has been resolved as of version 3.10.0 + of SQLite. SQLAlchemy as of **1.1** automatically disables its internal + workarounds based on detection of this version. + +The bug, entirely outside of SQLAlchemy, can be illustrated thusly:: + + import sqlite3 + + assert sqlite3.sqlite_version_info < (3, 10, 0), "bug is fixed in this version" + + conn = sqlite3.connect(":memory:") + cursor = conn.cursor() + + cursor.execute("create table x (a integer, b integer)") + cursor.execute("insert into x (a, b) values (1, 1)") + cursor.execute("insert into x (a, b) values (2, 2)") + + cursor.execute("select x.a, x.b from x") + assert [c[0] for c in cursor.description] == ['a', 'b'] + + cursor.execute(''' + select x.a, x.b from x where a=1 + union + select x.a, x.b from x where a=2 + ''') + assert [c[0] for c in cursor.description] == ['a', 'b'], \ + [c[0] for c in cursor.description] + +The second assertion fails:: + + Traceback (most recent call last): + File "test.py", line 19, in <module> + [c[0] for c in cursor.description] + AssertionError: ['x.a', 'x.b'] + +Where above, the driver incorrectly reports the names of the columns +including the name of the table, which is entirely inconsistent vs. +when the UNION is not present. + +SQLAlchemy relies upon column names being predictable in how they match +to the original statement, so the SQLAlchemy dialect has no choice but +to filter these out:: + + + from sqlalchemy import create_engine + + eng = create_engine("sqlite://") + conn = eng.connect() + + conn.exec_driver_sql("create table x (a integer, b integer)") + conn.exec_driver_sql("insert into x (a, b) values (1, 1)") + conn.exec_driver_sql("insert into x (a, b) values (2, 2)") + + result = conn.exec_driver_sql("select x.a, x.b from x") + assert result.keys() == ["a", "b"] + + result = conn.exec_driver_sql(''' + select x.a, x.b from x where a=1 + union + select x.a, x.b from x where a=2 + ''') + assert result.keys() == ["a", "b"] + +Note that above, even though SQLAlchemy filters out the dots, *both +names are still addressable*:: + + >>> row = result.first() + >>> row["a"] + 1 + >>> row["x.a"] + 1 + >>> row["b"] + 1 + >>> row["x.b"] + 1 + +Therefore, the workaround applied by SQLAlchemy only impacts +:meth:`_engine.CursorResult.keys` and :meth:`.Row.keys()` in the public API. In +the very specific case where an application is forced to use column names that +contain dots, and the functionality of :meth:`_engine.CursorResult.keys` and +:meth:`.Row.keys()` is required to return these dotted names unmodified, +the ``sqlite_raw_colnames`` execution option may be provided, either on a +per-:class:`_engine.Connection` basis:: + + result = conn.execution_options(sqlite_raw_colnames=True).exec_driver_sql(''' + select x.a, x.b from x where a=1 + union + select x.a, x.b from x where a=2 + ''') + assert result.keys() == ["x.a", "x.b"] + +or on a per-:class:`_engine.Engine` basis:: + + engine = create_engine("sqlite://", execution_options={"sqlite_raw_colnames": True}) + +When using the per-:class:`_engine.Engine` execution option, note that +**Core and ORM queries that use UNION may not function properly**. + +SQLite-specific table options +----------------------------- + +One option for CREATE TABLE is supported directly by the SQLite +dialect in conjunction with the :class:`_schema.Table` construct: + +* ``WITHOUT ROWID``:: + + Table("some_table", metadata, ..., sqlite_with_rowid=False) + +.. seealso:: + + `SQLite CREATE TABLE options + <https://www.sqlite.org/lang_createtable.html>`_ + +""" # noqa + +import datetime +import numbers +import re + +from .json import JSON +from .json import JSONIndexType +from .json import JSONPathType +from ... import exc +from ... import processors +from ... import schema as sa_schema +from ... import sql +from ... import types as sqltypes +from ... import util +from ...engine import default +from ...engine import reflection +from ...sql import coercions +from ...sql import ColumnElement +from ...sql import compiler +from ...sql import elements +from ...sql import roles +from ...sql import schema +from ...types import BLOB # noqa +from ...types import BOOLEAN # noqa +from ...types import CHAR # noqa +from ...types import DECIMAL # noqa +from ...types import FLOAT # noqa +from ...types import INTEGER # noqa +from ...types import NUMERIC # noqa +from ...types import REAL # noqa +from ...types import SMALLINT # noqa +from ...types import TEXT # noqa +from ...types import TIMESTAMP # noqa +from ...types import VARCHAR # noqa + + +class _SQliteJson(JSON): + def result_processor(self, dialect, coltype): + default_processor = super(_SQliteJson, self).result_processor( + dialect, coltype + ) + + def process(value): + try: + return default_processor(value) + except TypeError: + if isinstance(value, numbers.Number): + return value + else: + raise + + return process + + +class _DateTimeMixin(object): + _reg = None + _storage_format = None + + def __init__(self, storage_format=None, regexp=None, **kw): + super(_DateTimeMixin, self).__init__(**kw) + if regexp is not None: + self._reg = re.compile(regexp) + if storage_format is not None: + self._storage_format = storage_format + + @property + def format_is_text_affinity(self): + """return True if the storage format will automatically imply + a TEXT affinity. + + If the storage format contains no non-numeric characters, + it will imply a NUMERIC storage format on SQLite; in this case, + the type will generate its DDL as DATE_CHAR, DATETIME_CHAR, + TIME_CHAR. + + .. versionadded:: 1.0.0 + + """ + spec = self._storage_format % { + "year": 0, + "month": 0, + "day": 0, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, + } + return bool(re.search(r"[^0-9]", spec)) + + def adapt(self, cls, **kw): + if issubclass(cls, _DateTimeMixin): + if self._storage_format: + kw["storage_format"] = self._storage_format + if self._reg: + kw["regexp"] = self._reg + return super(_DateTimeMixin, self).adapt(cls, **kw) + + def literal_processor(self, dialect): + bp = self.bind_processor(dialect) + + def process(value): + return "'%s'" % bp(value) + + return process + + +class DATETIME(_DateTimeMixin, sqltypes.DateTime): + r"""Represent a Python datetime object in SQLite using a string. + + The default string storage format is:: + + "%(year)04d-%(month)02d-%(day)02d %(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + e.g.:: + + 2021-03-15 12:05:57.105542 + + The storage format can be customized to some degree using the + ``storage_format`` and ``regexp`` parameters, such as:: + + import re + from sqlalchemy.dialects.sqlite import DATETIME + + dt = DATETIME(storage_format="%(year)04d/%(month)02d/%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d", + regexp=r"(\d+)/(\d+)/(\d+) (\d+)-(\d+)-(\d+)" + ) + + :param storage_format: format string which will be applied to the dict + with keys year, month, day, hour, minute, second, and microsecond. + + :param regexp: regular expression which will be applied to incoming result + rows. If the regexp contains named groups, the resulting match dict is + applied to the Python datetime() constructor as keyword arguments. + Otherwise, if positional groups are used, the datetime() constructor + is called with positional arguments via + ``*map(int, match_obj.groups(0))``. + + """ # noqa + + _storage_format = ( + "%(year)04d-%(month)02d-%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + ) + + def __init__(self, *args, **kwargs): + truncate_microseconds = kwargs.pop("truncate_microseconds", False) + super(DATETIME, self).__init__(*args, **kwargs) + if truncate_microseconds: + assert "storage_format" not in kwargs, ( + "You can specify only " + "one of truncate_microseconds or storage_format." + ) + assert "regexp" not in kwargs, ( + "You can specify only one of " + "truncate_microseconds or regexp." + ) + self._storage_format = ( + "%(year)04d-%(month)02d-%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d" + ) + + def bind_processor(self, dialect): + datetime_datetime = datetime.datetime + datetime_date = datetime.date + format_ = self._storage_format + + def process(value): + if value is None: + return None + elif isinstance(value, datetime_datetime): + return format_ % { + "year": value.year, + "month": value.month, + "day": value.day, + "hour": value.hour, + "minute": value.minute, + "second": value.second, + "microsecond": value.microsecond, + } + elif isinstance(value, datetime_date): + return format_ % { + "year": value.year, + "month": value.month, + "day": value.day, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, + } + else: + raise TypeError( + "SQLite DateTime type only accepts Python " + "datetime and date objects as input." + ) + + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.datetime + ) + else: + return processors.str_to_datetime + + +class DATE(_DateTimeMixin, sqltypes.Date): + r"""Represent a Python date object in SQLite using a string. + + The default string storage format is:: + + "%(year)04d-%(month)02d-%(day)02d" + + e.g.:: + + 2011-03-15 + + The storage format can be customized to some degree using the + ``storage_format`` and ``regexp`` parameters, such as:: + + import re + from sqlalchemy.dialects.sqlite import DATE + + d = DATE( + storage_format="%(month)02d/%(day)02d/%(year)04d", + regexp=re.compile("(?P<month>\d+)/(?P<day>\d+)/(?P<year>\d+)") + ) + + :param storage_format: format string which will be applied to the + dict with keys year, month, and day. + + :param regexp: regular expression which will be applied to + incoming result rows. If the regexp contains named groups, the + resulting match dict is applied to the Python date() constructor + as keyword arguments. Otherwise, if positional groups are used, the + date() constructor is called with positional arguments via + ``*map(int, match_obj.groups(0))``. + """ + + _storage_format = "%(year)04d-%(month)02d-%(day)02d" + + def bind_processor(self, dialect): + datetime_date = datetime.date + format_ = self._storage_format + + def process(value): + if value is None: + return None + elif isinstance(value, datetime_date): + return format_ % { + "year": value.year, + "month": value.month, + "day": value.day, + } + else: + raise TypeError( + "SQLite Date type only accepts Python " + "date objects as input." + ) + + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.date + ) + else: + return processors.str_to_date + + +class TIME(_DateTimeMixin, sqltypes.Time): + r"""Represent a Python time object in SQLite using a string. + + The default string storage format is:: + + "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + e.g.:: + + 12:05:57.10558 + + The storage format can be customized to some degree using the + ``storage_format`` and ``regexp`` parameters, such as:: + + import re + from sqlalchemy.dialects.sqlite import TIME + + t = TIME(storage_format="%(hour)02d-%(minute)02d-" + "%(second)02d-%(microsecond)06d", + regexp=re.compile("(\d+)-(\d+)-(\d+)-(?:-(\d+))?") + ) + + :param storage_format: format string which will be applied to the dict + with keys hour, minute, second, and microsecond. + + :param regexp: regular expression which will be applied to incoming result + rows. If the regexp contains named groups, the resulting match dict is + applied to the Python time() constructor as keyword arguments. Otherwise, + if positional groups are used, the time() constructor is called with + positional arguments via ``*map(int, match_obj.groups(0))``. + """ + + _storage_format = "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + def __init__(self, *args, **kwargs): + truncate_microseconds = kwargs.pop("truncate_microseconds", False) + super(TIME, self).__init__(*args, **kwargs) + if truncate_microseconds: + assert "storage_format" not in kwargs, ( + "You can specify only " + "one of truncate_microseconds or storage_format." + ) + assert "regexp" not in kwargs, ( + "You can specify only one of " + "truncate_microseconds or regexp." + ) + self._storage_format = "%(hour)02d:%(minute)02d:%(second)02d" + + def bind_processor(self, dialect): + datetime_time = datetime.time + format_ = self._storage_format + + def process(value): + if value is None: + return None + elif isinstance(value, datetime_time): + return format_ % { + "hour": value.hour, + "minute": value.minute, + "second": value.second, + "microsecond": value.microsecond, + } + else: + raise TypeError( + "SQLite Time type only accepts Python " + "time objects as input." + ) + + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.time + ) + else: + return processors.str_to_time + + +colspecs = { + sqltypes.Date: DATE, + sqltypes.DateTime: DATETIME, + sqltypes.JSON: _SQliteJson, + sqltypes.JSON.JSONIndexType: JSONIndexType, + sqltypes.JSON.JSONPathType: JSONPathType, + sqltypes.Time: TIME, +} + +ischema_names = { + "BIGINT": sqltypes.BIGINT, + "BLOB": sqltypes.BLOB, + "BOOL": sqltypes.BOOLEAN, + "BOOLEAN": sqltypes.BOOLEAN, + "CHAR": sqltypes.CHAR, + "DATE": sqltypes.DATE, + "DATE_CHAR": sqltypes.DATE, + "DATETIME": sqltypes.DATETIME, + "DATETIME_CHAR": sqltypes.DATETIME, + "DOUBLE": sqltypes.FLOAT, + "DECIMAL": sqltypes.DECIMAL, + "FLOAT": sqltypes.FLOAT, + "INT": sqltypes.INTEGER, + "INTEGER": sqltypes.INTEGER, + "JSON": JSON, + "NUMERIC": sqltypes.NUMERIC, + "REAL": sqltypes.REAL, + "SMALLINT": sqltypes.SMALLINT, + "TEXT": sqltypes.TEXT, + "TIME": sqltypes.TIME, + "TIME_CHAR": sqltypes.TIME, + "TIMESTAMP": sqltypes.TIMESTAMP, + "VARCHAR": sqltypes.VARCHAR, + "NVARCHAR": sqltypes.NVARCHAR, + "NCHAR": sqltypes.NCHAR, +} + + +class SQLiteCompiler(compiler.SQLCompiler): + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + "month": "%m", + "day": "%d", + "year": "%Y", + "second": "%S", + "hour": "%H", + "doy": "%j", + "minute": "%M", + "epoch": "%s", + "dow": "%w", + "week": "%W", + }, + ) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_localtimestamp_func(self, func, **kw): + return 'DATETIME(CURRENT_TIMESTAMP, "localtime")' + + def visit_true(self, expr, **kw): + return "1" + + def visit_false(self, expr, **kw): + return "0" + + def visit_char_length_func(self, fn, **kw): + return "length%s" % self.function_argspec(fn) + + def visit_cast(self, cast, **kwargs): + if self.dialect.supports_cast: + return super(SQLiteCompiler, self).visit_cast(cast, **kwargs) + else: + return self.process(cast.clause, **kwargs) + + def visit_extract(self, extract, **kw): + try: + return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( + self.extract_map[extract.field], + self.process(extract.expr, **kw), + ) + except KeyError as err: + util.raise_( + exc.CompileError( + "%s is not a valid extract argument." % extract.field + ), + replace_context=err, + ) + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + text += "\n LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: + text += "\n LIMIT " + self.process(sql.literal(-1)) + text += " OFFSET " + self.process(select._offset_clause, **kw) + else: + text += " OFFSET " + self.process(sql.literal(0), **kw) + return text + + def for_update_clause(self, select, **kw): + # sqlite has no "FOR UPDATE" AFAICT + return "" + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "%s IS NOT %s" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "%s IS %s" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + if binary.type._type_affinity is sqltypes.JSON: + expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))" + else: + expr = "JSON_EXTRACT(%s, %s)" + + return expr % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + if binary.type._type_affinity is sqltypes.JSON: + expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))" + else: + expr = "JSON_EXTRACT(%s, %s)" + + return expr % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_empty_set_op_expr(self, type_, expand_op): + # slightly old SQLite versions don't seem to be able to handle + # the empty set impl + return self.visit_empty_set_expr(type_) + + def visit_empty_set_expr(self, element_types): + return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % ( + ", ".join("1" for type_ in element_types or [INTEGER()]), + ", ".join("1" for type_ in element_types or [INTEGER()]), + ) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " REGEXP ", **kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " NOT REGEXP ", **kw) + + def _on_conflict_target(self, clause, **kw): + if clause.constraint_target is not None: + target_text = "(%s)" % clause.constraint_target + elif clause.inferred_target_elements is not None: + target_text = "(%s)" % ", ".join( + ( + self.preparer.quote(c) + if isinstance(c, util.string_types) + else self.process(c, include_table=False, use_schema=False) + ) + for c in clause.inferred_target_elements + ) + if clause.inferred_target_whereclause is not None: + target_text += " WHERE %s" % self.process( + clause.inferred_target_whereclause, + include_table=False, + use_schema=False, + literal_binds=True, + ) + + else: + target_text = "" + + return target_text + + def visit_on_conflict_do_nothing(self, on_conflict, **kw): + + target_text = self._on_conflict_target(on_conflict, **kw) + + if target_text: + return "ON CONFLICT %s DO NOTHING" % target_text + else: + return "ON CONFLICT DO NOTHING" + + def visit_on_conflict_do_update(self, on_conflict, **kw): + clause = on_conflict + + target_text = self._on_conflict_target(on_conflict, **kw) + + action_set_ops = [] + + set_parameters = dict(clause.update_values_to_set) + # create a list of column assignment clauses as tuples + + insert_statement = self.stack[-1]["selectable"] + cols = insert_statement.table.c + for c in cols: + col_key = c.key + + if col_key in set_parameters: + value = set_parameters.pop(col_key) + elif c in set_parameters: + value = set_parameters.pop(c) + else: + continue + + if coercions._is_literal(value): + value = elements.BindParameter(None, value, type_=c.type) + + else: + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): + value = value._clone() + value.type = c.type + value_text = self.process(value.self_group(), use_schema=False) + + key_text = self.preparer.quote(c.name) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + # check for names that don't match columns + if set_parameters: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( + self.current_executable.table.name, + (", ".join("'%s'" % c for c in set_parameters)), + ) + ) + for k, v in set_parameters.items(): + key_text = ( + self.preparer.quote(k) + if isinstance(k, util.string_types) + else self.process(k, use_schema=False) + ) + value_text = self.process( + coercions.expect(roles.ExpressionElementRole, v), + use_schema=False, + ) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + action_text = ", ".join(action_set_ops) + if clause.update_whereclause is not None: + action_text += " WHERE %s" % self.process( + clause.update_whereclause, include_table=True, use_schema=False + ) + + return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text) + + +class SQLiteDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + + coltype = self.dialect.type_compiler.process( + column.type, type_expression=column + ) + colspec = self.preparer.format_column(column) + " " + coltype + default = self.get_column_default_string(column) + if default is not None: + if isinstance(column.server_default.arg, ColumnElement): + default = "(" + default + ")" + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_not_null" + ] + if on_conflict_clause is not None: + colspec += " ON CONFLICT " + on_conflict_clause + + if column.primary_key: + if ( + column.autoincrement is True + and len(column.table.primary_key.columns) != 1 + ): + raise exc.CompileError( + "SQLite does not support autoincrement for " + "composite primary keys" + ) + + if ( + column.table.dialect_options["sqlite"]["autoincrement"] + and len(column.table.primary_key.columns) == 1 + and issubclass(column.type._type_affinity, sqltypes.Integer) + and not column.foreign_keys + ): + colspec += " PRIMARY KEY" + + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_primary_key" + ] + if on_conflict_clause is not None: + colspec += " ON CONFLICT " + on_conflict_clause + + colspec += " AUTOINCREMENT" + + if column.computed is not None: + colspec += " " + self.process(column.computed) + + return colspec + + def visit_primary_key_constraint(self, constraint): + # for columns with sqlite_autoincrement=True, + # the PRIMARY KEY constraint can only be inline + # with the column itself. + if len(constraint.columns) == 1: + c = list(constraint)[0] + if ( + c.primary_key + and c.table.dialect_options["sqlite"]["autoincrement"] + and issubclass(c.type._type_affinity, sqltypes.Integer) + and not c.foreign_keys + ): + return None + + text = super(SQLiteDDLCompiler, self).visit_primary_key_constraint( + constraint + ) + + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] + if on_conflict_clause is None and len(constraint.columns) == 1: + on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][ + "on_conflict_primary_key" + ] + + if on_conflict_clause is not None: + text += " ON CONFLICT " + on_conflict_clause + + return text + + def visit_unique_constraint(self, constraint): + text = super(SQLiteDDLCompiler, self).visit_unique_constraint( + constraint + ) + + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] + if on_conflict_clause is None and len(constraint.columns) == 1: + col1 = list(constraint)[0] + if isinstance(col1, schema.SchemaItem): + on_conflict_clause = list(constraint)[0].dialect_options[ + "sqlite" + ]["on_conflict_unique"] + + if on_conflict_clause is not None: + text += " ON CONFLICT " + on_conflict_clause + + return text + + def visit_check_constraint(self, constraint): + text = super(SQLiteDDLCompiler, self).visit_check_constraint( + constraint + ) + + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] + + if on_conflict_clause is not None: + text += " ON CONFLICT " + on_conflict_clause + + return text + + def visit_column_check_constraint(self, constraint): + text = super(SQLiteDDLCompiler, self).visit_column_check_constraint( + constraint + ) + + if constraint.dialect_options["sqlite"]["on_conflict"] is not None: + raise exc.CompileError( + "SQLite does not support on conflict clause for " + "column check constraint" + ) + + return text + + def visit_foreign_key_constraint(self, constraint): + + local_table = constraint.elements[0].parent.table + remote_table = constraint.elements[0].column.table + + if local_table.schema != remote_table.schema: + return None + else: + return super(SQLiteDDLCompiler, self).visit_foreign_key_constraint( + constraint + ) + + def define_constraint_remote_table(self, constraint, table, preparer): + """Format the remote table clause of a CREATE CONSTRAINT clause.""" + + return preparer.format_table(table, use_schema=False) + + def visit_create_index( + self, create, include_schema=False, include_table_schema=True + ): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + + text += "INDEX " + + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += "%s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=False), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) + + whereclause = index.dialect_options["sqlite"]["where"] + if whereclause is not None: + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, literal_binds=True + ) + text += " WHERE " + where_compiled + + return text + + def post_create_table(self, table): + if table.dialect_options["sqlite"]["with_rowid"] is False: + return "\n WITHOUT ROWID" + return "" + + +class SQLiteTypeCompiler(compiler.GenericTypeCompiler): + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_) + + def visit_DATETIME(self, type_, **kw): + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): + return super(SQLiteTypeCompiler, self).visit_DATETIME(type_) + else: + return "DATETIME_CHAR" + + def visit_DATE(self, type_, **kw): + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): + return super(SQLiteTypeCompiler, self).visit_DATE(type_) + else: + return "DATE_CHAR" + + def visit_TIME(self, type_, **kw): + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): + return super(SQLiteTypeCompiler, self).visit_TIME(type_) + else: + return "TIME_CHAR" + + def visit_JSON(self, type_, **kw): + # note this name provides NUMERIC affinity, not TEXT. + # should not be an issue unless the JSON value consists of a single + # numeric value. JSONTEXT can be used if this case is required. + return "JSON" + + +class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = set( + [ + "add", + "after", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "attach", + "autoincrement", + "before", + "begin", + "between", + "by", + "cascade", + "case", + "cast", + "check", + "collate", + "column", + "commit", + "conflict", + "constraint", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "database", + "default", + "deferrable", + "deferred", + "delete", + "desc", + "detach", + "distinct", + "drop", + "each", + "else", + "end", + "escape", + "except", + "exclusive", + "exists", + "explain", + "false", + "fail", + "for", + "foreign", + "from", + "full", + "glob", + "group", + "having", + "if", + "ignore", + "immediate", + "in", + "index", + "indexed", + "initially", + "inner", + "insert", + "instead", + "intersect", + "into", + "is", + "isnull", + "join", + "key", + "left", + "like", + "limit", + "match", + "natural", + "not", + "notnull", + "null", + "of", + "offset", + "on", + "or", + "order", + "outer", + "plan", + "pragma", + "primary", + "query", + "raise", + "references", + "reindex", + "rename", + "replace", + "restrict", + "right", + "rollback", + "row", + "select", + "set", + "table", + "temp", + "temporary", + "then", + "to", + "transaction", + "trigger", + "true", + "union", + "unique", + "update", + "using", + "vacuum", + "values", + "view", + "virtual", + "when", + "where", + ] + ) + + +class SQLiteExecutionContext(default.DefaultExecutionContext): + @util.memoized_property + def _preserve_raw_colnames(self): + return ( + not self.dialect._broken_dotted_colnames + or self.execution_options.get("sqlite_raw_colnames", False) + ) + + def _translate_colname(self, colname): + # TODO: detect SQLite version 3.10.0 or greater; + # see [ticket:3633] + + # adjust for dotted column names. SQLite + # in the case of UNION may store col names as + # "tablename.colname", or if using an attached database, + # "database.tablename.colname", in cursor.description + if not self._preserve_raw_colnames and "." in colname: + return colname.split(".")[-1], colname + else: + return colname, None + + +class SQLiteDialect(default.DefaultDialect): + name = "sqlite" + supports_alter = False + supports_unicode_statements = True + supports_unicode_binds = True + + # SQlite supports "DEFAULT VALUES" but *does not* support + # "VALUES (DEFAULT)" + supports_default_values = True + supports_default_metavalue = False + + supports_empty_insert = False + supports_cast = True + supports_multivalues_insert = True + tuple_in_values = True + supports_statement_cache = True + + default_paramstyle = "qmark" + execution_ctx_cls = SQLiteExecutionContext + statement_compiler = SQLiteCompiler + ddl_compiler = SQLiteDDLCompiler + type_compiler = SQLiteTypeCompiler + preparer = SQLiteIdentifierPreparer + ischema_names = ischema_names + colspecs = colspecs + isolation_level = None + + construct_arguments = [ + ( + sa_schema.Table, + { + "autoincrement": False, + "with_rowid": True, + }, + ), + (sa_schema.Index, {"where": None}), + ( + sa_schema.Column, + { + "on_conflict_primary_key": None, + "on_conflict_not_null": None, + "on_conflict_unique": None, + }, + ), + (sa_schema.Constraint, {"on_conflict": None}), + ] + + _broken_fk_pragma_quotes = False + _broken_dotted_colnames = False + + @util.deprecated_params( + _json_serializer=( + "1.3.7", + "The _json_serializer argument to the SQLite dialect has " + "been renamed to the correct name of json_serializer. The old " + "argument name will be removed in a future release.", + ), + _json_deserializer=( + "1.3.7", + "The _json_deserializer argument to the SQLite dialect has " + "been renamed to the correct name of json_deserializer. The old " + "argument name will be removed in a future release.", + ), + ) + def __init__( + self, + isolation_level=None, + native_datetime=False, + json_serializer=None, + json_deserializer=None, + _json_serializer=None, + _json_deserializer=None, + **kwargs + ): + default.DefaultDialect.__init__(self, **kwargs) + self.isolation_level = isolation_level + + if _json_serializer: + json_serializer = _json_serializer + if _json_deserializer: + json_deserializer = _json_deserializer + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer + + # this flag used by pysqlite dialect, and perhaps others in the + # future, to indicate the driver is handling date/timestamp + # conversions (and perhaps datetime/time as well on some hypothetical + # driver ?) + self.native_datetime = native_datetime + + if self.dbapi is not None: + if self.dbapi.sqlite_version_info < (3, 7, 16): + util.warn( + "SQLite version %s is older than 3.7.16, and will not " + "support right nested joins, as are sometimes used in " + "more complex ORM scenarios. SQLAlchemy 1.4 and above " + "no longer tries to rewrite these joins." + % (self.dbapi.sqlite_version_info,) + ) + + self._broken_dotted_colnames = self.dbapi.sqlite_version_info < ( + 3, + 10, + 0, + ) + self.supports_default_values = self.dbapi.sqlite_version_info >= ( + 3, + 3, + 8, + ) + self.supports_cast = self.dbapi.sqlite_version_info >= (3, 2, 3) + self.supports_multivalues_insert = ( + # https://www.sqlite.org/releaselog/3_7_11.html + self.dbapi.sqlite_version_info + >= (3, 7, 11) + ) + # see https://www.sqlalchemy.org/trac/ticket/2568 + # as well as https://www.sqlite.org/src/info/600482d161 + self._broken_fk_pragma_quotes = self.dbapi.sqlite_version_info < ( + 3, + 6, + 14, + ) + + _isolation_lookup = util.immutabledict( + {"READ UNCOMMITTED": 1, "SERIALIZABLE": 0} + ) + + def set_isolation_level(self, connection, level): + try: + isolation_level = self._isolation_lookup[level.replace("_", " ")] + except KeyError as err: + util.raise_( + exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" + % ( + level, + self.name, + ", ".join(self._isolation_lookup), + ) + ), + replace_context=err, + ) + cursor = connection.cursor() + cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level) + cursor.close() + + def get_isolation_level(self, connection): + cursor = connection.cursor() + cursor.execute("PRAGMA read_uncommitted") + res = cursor.fetchone() + if res: + value = res[0] + else: + # https://www.sqlite.org/changes.html#version_3_3_3 + # "Optional READ UNCOMMITTED isolation (instead of the + # default isolation level of SERIALIZABLE) and + # table level locking when database connections + # share a common cache."" + # pre-SQLite 3.3.0 default to 0 + value = 0 + cursor.close() + if value == 0: + return "SERIALIZABLE" + elif value == 1: + return "READ UNCOMMITTED" + else: + assert False, "Unknown isolation level %s" % value + + def on_connect(self): + if self.isolation_level is not None: + + def connect(conn): + self.set_isolation_level(conn, self.isolation_level) + + return connect + else: + return None + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = "PRAGMA database_list" + dl = connection.exec_driver_sql(s) + + return [db[1] for db in dl if db[1] != "temp"] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = "%s.sqlite_master" % qschema + else: + master = "sqlite_master" + s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % ( + master, + ) + rs = connection.exec_driver_sql(s) + return [row[0] for row in rs] + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + s = ( + "SELECT name FROM sqlite_temp_master " + "WHERE type='table' ORDER BY name " + ) + rs = connection.exec_driver_sql(s) + + return [row[0] for row in rs] + + @reflection.cache + def get_temp_view_names(self, connection, **kw): + s = ( + "SELECT name FROM sqlite_temp_master " + "WHERE type='view' ORDER BY name " + ) + rs = connection.exec_driver_sql(s) + + return [row[0] for row in rs] + + def has_table(self, connection, table_name, schema=None): + self._ensure_has_table_connection(connection) + + info = self._get_table_pragma( + connection, "table_info", table_name, schema=schema + ) + return bool(info) + + def _get_default_schema_name(self, connection): + return "main" + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = "%s.sqlite_master" % qschema + else: + master = "sqlite_master" + s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % ( + master, + ) + rs = connection.exec_driver_sql(s) + + return [row[0] for row in rs] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = "%s.sqlite_master" % qschema + s = ("SELECT sql FROM %s WHERE name = ? AND type='view'") % ( + master, + ) + rs = connection.exec_driver_sql(s, (view_name,)) + else: + try: + s = ( + "SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = ? " + "AND type='view'" + ) + rs = connection.exec_driver_sql(s, (view_name,)) + except exc.DBAPIError: + s = ( + "SELECT sql FROM sqlite_master WHERE name = ? " + "AND type='view'" + ) + rs = connection.exec_driver_sql(s, (view_name,)) + + result = rs.fetchall() + if result: + return result[0].sql + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + pragma = "table_info" + # computed columns are threaded as hidden, they require table_xinfo + if self.server_version_info >= (3, 31): + pragma = "table_xinfo" + info = self._get_table_pragma( + connection, pragma, table_name, schema=schema + ) + columns = [] + tablesql = None + for row in info: + name = row[1] + type_ = row[2].upper() + nullable = not row[3] + default = row[4] + primary_key = row[5] + hidden = row[6] if pragma == "table_xinfo" else 0 + + # hidden has value 0 for normal columns, 1 for hidden columns, + # 2 for computed virtual columns and 3 for computed stored columns + # https://www.sqlite.org/src/info/069351b85f9a706f60d3e98fbc8aaf40c374356b967c0464aede30ead3d9d18b + if hidden == 1: + continue + + generated = bool(hidden) + persisted = hidden == 3 + + if tablesql is None and generated: + tablesql = self._get_table_sql( + connection, table_name, schema, **kw + ) + + columns.append( + self._get_column_info( + name, + type_, + nullable, + default, + primary_key, + generated, + persisted, + tablesql, + ) + ) + return columns + + def _get_column_info( + self, + name, + type_, + nullable, + default, + primary_key, + generated, + persisted, + tablesql, + ): + + if generated: + # the type of a column "cc INTEGER GENERATED ALWAYS AS (1 + 42)" + # somehow is "INTEGER GENERATED ALWAYS" + type_ = re.sub("generated", "", type_, flags=re.IGNORECASE) + type_ = re.sub("always", "", type_, flags=re.IGNORECASE).strip() + + coltype = self._resolve_type_affinity(type_) + + if default is not None: + default = util.text_type(default) + + colspec = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": "auto", + "primary_key": primary_key, + } + if generated: + sqltext = "" + if tablesql: + pattern = r"[^,]*\s+AS\s+\(([^,]*)\)\s*(?:virtual|stored)?" + match = re.search( + re.escape(name) + pattern, tablesql, re.IGNORECASE + ) + if match: + sqltext = match.group(1) + colspec["computed"] = {"sqltext": sqltext, "persisted": persisted} + return colspec + + def _resolve_type_affinity(self, type_): + """Return a data type from a reflected column, using affinity rules. + + SQLite's goal for universal compatibility introduces some complexity + during reflection, as a column's defined type might not actually be a + type that SQLite understands - or indeed, my not be defined *at all*. + Internally, SQLite handles this with a 'data type affinity' for each + column definition, mapping to one of 'TEXT', 'NUMERIC', 'INTEGER', + 'REAL', or 'NONE' (raw bits). The algorithm that determines this is + listed in https://www.sqlite.org/datatype3.html section 2.1. + + This method allows SQLAlchemy to support that algorithm, while still + providing access to smarter reflection utilities by recognizing + column definitions that SQLite only supports through affinity (like + DATE and DOUBLE). + + """ + match = re.match(r"([\w ]+)(\(.*?\))?", type_) + if match: + coltype = match.group(1) + args = match.group(2) + else: + coltype = "" + args = "" + + if coltype in self.ischema_names: + coltype = self.ischema_names[coltype] + elif "INT" in coltype: + coltype = sqltypes.INTEGER + elif "CHAR" in coltype or "CLOB" in coltype or "TEXT" in coltype: + coltype = sqltypes.TEXT + elif "BLOB" in coltype or not coltype: + coltype = sqltypes.NullType + elif "REAL" in coltype or "FLOA" in coltype or "DOUB" in coltype: + coltype = sqltypes.REAL + else: + coltype = sqltypes.NUMERIC + + if args is not None: + args = re.findall(r"(\d+)", args) + try: + coltype = coltype(*[int(a) for a in args]) + except TypeError: + util.warn( + "Could not instantiate type %s with " + "reflected arguments %s; using no arguments." + % (coltype, args) + ) + coltype = coltype() + else: + coltype = coltype() + + return coltype + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + constraint_name = None + table_data = self._get_table_sql(connection, table_name, schema=schema) + if table_data: + PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY" + result = re.search(PK_PATTERN, table_data, re.I) + constraint_name = result.group(1) if result else None + + cols = self.get_columns(connection, table_name, schema, **kw) + cols.sort(key=lambda col: col.get("primary_key")) + pkeys = [] + for col in cols: + if col["primary_key"]: + pkeys.append(col["name"]) + + return {"constrained_columns": pkeys, "name": constraint_name} + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # sqlite makes this *extremely difficult*. + # First, use the pragma to get the actual FKs. + pragma_fks = self._get_table_pragma( + connection, "foreign_key_list", table_name, schema=schema + ) + + fks = {} + + for row in pragma_fks: + (numerical_id, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4]) + + if not rcol: + # no referred column, which means it was not named in the + # original DDL. The referred columns of the foreign key + # constraint are therefore the primary key of the referred + # table. + referred_pk = self.get_pk_constraint( + connection, rtbl, schema=schema, **kw + ) + # note that if table doesn't exist, we still get back a record, + # just it has no columns in it + referred_columns = referred_pk["constrained_columns"] + else: + # note we use this list only if this is the first column + # in the constraint. for subsequent columns we ignore the + # list and append "rcol" if present. + referred_columns = [] + + if self._broken_fk_pragma_quotes: + rtbl = re.sub(r"^[\"\[`\']|[\"\]`\']$", "", rtbl) + + if numerical_id in fks: + fk = fks[numerical_id] + else: + fk = fks[numerical_id] = { + "name": None, + "constrained_columns": [], + "referred_schema": schema, + "referred_table": rtbl, + "referred_columns": referred_columns, + "options": {}, + } + fks[numerical_id] = fk + + fk["constrained_columns"].append(lcol) + + if rcol: + fk["referred_columns"].append(rcol) + + def fk_sig(constrained_columns, referred_table, referred_columns): + return ( + tuple(constrained_columns) + + (referred_table,) + + tuple(referred_columns) + ) + + # then, parse the actual SQL and attempt to find DDL that matches + # the names as well. SQLite saves the DDL in whatever format + # it was typed in as, so need to be liberal here. + + keys_by_signature = dict( + ( + fk_sig( + fk["constrained_columns"], + fk["referred_table"], + fk["referred_columns"], + ), + fk, + ) + for fk in fks.values() + ) + + table_data = self._get_table_sql(connection, table_name, schema=schema) + if table_data is None: + # system tables, etc. + return [] + + def parse_fks(): + FK_PATTERN = ( + r"(?:CONSTRAINT (\w+) +)?" + r"FOREIGN KEY *\( *(.+?) *\) +" + r'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\((.+?)\) *' + r"((?:ON (?:DELETE|UPDATE) " + r"(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)" + ) + for match in re.finditer(FK_PATTERN, table_data, re.I): + ( + constraint_name, + constrained_columns, + referred_quoted_name, + referred_name, + referred_columns, + onupdatedelete, + ) = match.group(1, 2, 3, 4, 5, 6) + constrained_columns = list( + self._find_cols_in_sig(constrained_columns) + ) + if not referred_columns: + referred_columns = constrained_columns + else: + referred_columns = list( + self._find_cols_in_sig(referred_columns) + ) + referred_name = referred_quoted_name or referred_name + options = {} + + for token in re.split(r" *\bON\b *", onupdatedelete.upper()): + if token.startswith("DELETE"): + ondelete = token[6:].strip() + if ondelete and ondelete != "NO ACTION": + options["ondelete"] = ondelete + elif token.startswith("UPDATE"): + onupdate = token[6:].strip() + if onupdate and onupdate != "NO ACTION": + options["onupdate"] = onupdate + yield ( + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) + + fkeys = [] + + for ( + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) in parse_fks(): + sig = fk_sig(constrained_columns, referred_name, referred_columns) + if sig not in keys_by_signature: + util.warn( + "WARNING: SQL-parsed foreign key constraint " + "'%s' could not be located in PRAGMA " + "foreign_keys for table %s" % (sig, table_name) + ) + continue + key = keys_by_signature.pop(sig) + key["name"] = constraint_name + key["options"] = options + fkeys.append(key) + # assume the remainders are the unnamed, inline constraints, just + # use them as is as it's extremely difficult to parse inline + # constraints + fkeys.extend(keys_by_signature.values()) + return fkeys + + def _find_cols_in_sig(self, sig): + for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I): + yield match.group(1) or match.group(2) + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + + auto_index_by_sig = {} + for idx in self.get_indexes( + connection, + table_name, + schema=schema, + include_auto_indexes=True, + **kw + ): + if not idx["name"].startswith("sqlite_autoindex"): + continue + sig = tuple(idx["column_names"]) + auto_index_by_sig[sig] = idx + + table_data = self._get_table_sql( + connection, table_name, schema=schema, **kw + ) + if not table_data: + return [] + + unique_constraints = [] + + def parse_uqs(): + UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)' + INLINE_UNIQUE_PATTERN = ( + r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) ' + r"+[a-z0-9_ ]+? +UNIQUE" + ) + + for match in re.finditer(UNIQUE_PATTERN, table_data, re.I): + name, cols = match.group(1, 2) + yield name, list(self._find_cols_in_sig(cols)) + + # we need to match inlines as well, as we seek to differentiate + # a UNIQUE constraint from a UNIQUE INDEX, even though these + # are kind of the same thing :) + for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I): + cols = list( + self._find_cols_in_sig(match.group(1) or match.group(2)) + ) + yield None, cols + + for name, cols in parse_uqs(): + sig = tuple(cols) + if sig in auto_index_by_sig: + auto_index_by_sig.pop(sig) + parsed_constraint = {"name": name, "column_names": cols} + unique_constraints.append(parsed_constraint) + # NOTE: auto_index_by_sig might not be empty here, + # the PRIMARY KEY may have an entry. + return unique_constraints + + @reflection.cache + def get_check_constraints(self, connection, table_name, schema=None, **kw): + table_data = self._get_table_sql( + connection, table_name, schema=schema, **kw + ) + if not table_data: + return [] + + CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *" + check_constraints = [] + # NOTE: we aren't using re.S here because we actually are + # taking advantage of each CHECK constraint being all on one + # line in the table definition in order to delineate. This + # necessarily makes assumptions as to how the CREATE TABLE + # was emitted. + + for match in re.finditer(CHECK_PATTERN, table_data, re.I): + name = match.group(1) + + if name: + name = re.sub(r'^"|"$', "", name) + + check_constraints.append({"sqltext": match.group(2), "name": name}) + + return check_constraints + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + pragma_indexes = self._get_table_pragma( + connection, "index_list", table_name, schema=schema + ) + indexes = [] + + include_auto_indexes = kw.pop("include_auto_indexes", False) + for row in pragma_indexes: + # ignore implicit primary key index. + # https://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html + if not include_auto_indexes and row[1].startswith( + "sqlite_autoindex" + ): + continue + indexes.append(dict(name=row[1], column_names=[], unique=row[2])) + + # loop thru unique indexes to get the column names. + for idx in list(indexes): + pragma_index = self._get_table_pragma( + connection, "index_info", idx["name"] + ) + + for row in pragma_index: + if row[2] is None: + util.warn( + "Skipped unsupported reflection of " + "expression-based index %s" % idx["name"] + ) + indexes.remove(idx) + break + else: + idx["column_names"].append(row[2]) + return indexes + + @reflection.cache + def _get_table_sql(self, connection, table_name, schema=None, **kw): + if schema: + schema_expr = "%s." % ( + self.identifier_preparer.quote_identifier(schema) + ) + else: + schema_expr = "" + try: + s = ( + "SELECT sql FROM " + " (SELECT * FROM %(schema)ssqlite_master UNION ALL " + " SELECT * FROM %(schema)ssqlite_temp_master) " + "WHERE name = ? " + "AND type = 'table'" % {"schema": schema_expr} + ) + rs = connection.exec_driver_sql(s, (table_name,)) + except exc.DBAPIError: + s = ( + "SELECT sql FROM %(schema)ssqlite_master " + "WHERE name = ? " + "AND type = 'table'" % {"schema": schema_expr} + ) + rs = connection.exec_driver_sql(s, (table_name,)) + return rs.scalar() + + def _get_table_pragma(self, connection, pragma, table_name, schema=None): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + statements = ["PRAGMA %s." % quote(schema)] + else: + # because PRAGMA looks in all attached databases if no schema + # given, need to specify "main" schema, however since we want + # 'temp' tables in the same namespace as 'main', need to run + # the PRAGMA twice + statements = ["PRAGMA main.", "PRAGMA temp."] + + qtable = quote(table_name) + for statement in statements: + statement = "%s%s(%s)" % (statement, pragma, qtable) + cursor = connection.exec_driver_sql(statement) + if not cursor._soft_closed: + # work around SQLite issue whereby cursor.description + # is blank when PRAGMA returns no rows: + # https://www.sqlite.org/cvstrac/tktview?tn=1884 + result = cursor.fetchall() + else: + result = [] + if result: + return result + else: + return [] diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py new file mode 100644 index 0000000..b04a5e6 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -0,0 +1,200 @@ +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from ... import util +from ...sql import coercions +from ...sql import roles +from ...sql.base import _exclusive_against +from ...sql.base import _generative +from ...sql.base import ColumnCollection +from ...sql.dml import Insert as StandardInsert +from ...sql.elements import ClauseElement +from ...sql.expression import alias +from ...util.langhelpers import public_factory + + +__all__ = ("Insert", "insert") + + +class Insert(StandardInsert): + """SQLite-specific implementation of INSERT. + + Adds methods for SQLite-specific syntaxes such as ON CONFLICT. + + The :class:`_sqlite.Insert` object is created using the + :func:`sqlalchemy.dialects.sqlite.insert` function. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`sqlite_on_conflict_insert` + + """ + + stringify_dialect = "sqlite" + inherit_cache = False + + @util.memoized_property + def excluded(self): + """Provide the ``excluded`` namespace for an ON CONFLICT statement + + SQLite's ON CONFLICT clause allows reference to the row that would + be inserted, known as ``excluded``. This attribute provides + all columns in this row to be referenceable. + + .. tip:: The :attr:`_sqlite.Insert.excluded` attribute is an instance + of :class:`_expression.ColumnCollection`, which provides an + interface the same as that of the :attr:`_schema.Table.c` + collection described at :ref:`metadata_tables_and_columns`. + With this collection, ordinary names are accessible like attributes + (e.g. ``stmt.excluded.some_column``), but special names and + dictionary method names should be accessed using indexed access, + such as ``stmt.excluded["column name"]`` or + ``stmt.excluded["values"]``. See the docstring for + :class:`_expression.ColumnCollection` for further examples. + + """ + return alias(self.table, name="excluded").columns + + _on_conflict_exclusive = _exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already has " + "an ON CONFLICT clause established" + }, + ) + + @_generative + @_on_conflict_exclusive + def on_conflict_do_update( + self, + index_elements=None, + index_where=None, + set_=None, + where=None, + ): + r""" + Specifies a DO UPDATE SET action for ON CONFLICT clause. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index or unique constraint. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. + + :param set\_: + A dictionary or other mapping object + where the keys are either names of columns in the target table, + or :class:`_schema.Column` objects or other ORM-mapped columns + matching that of the target table, and expressions or literals + as values, specifying the ``SET`` actions to take. + + .. versionadded:: 1.4 The + :paramref:`_sqlite.Insert.on_conflict_do_update.set_` + parameter supports :class:`_schema.Column` objects from the target + :class:`_schema.Table` as keys. + + .. warning:: This dictionary does **not** take into account + Python-specified default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON CONFLICT style of + UPDATE, unless they are manually specified in the + :paramref:`.Insert.on_conflict_do_update.set_` dictionary. + + :param where: + Optional argument. If present, can be a literal SQL + string or an acceptable expression for a ``WHERE`` clause + that restricts the rows affected by ``DO UPDATE SET``. Rows + not meeting the ``WHERE`` condition will not be updated + (effectively a ``DO NOTHING`` for those rows). + + """ + + self._post_values_clause = OnConflictDoUpdate( + index_elements, index_where, set_, where + ) + + @_generative + @_on_conflict_exclusive + def on_conflict_do_nothing(self, index_elements=None, index_where=None): + """ + Specifies a DO NOTHING action for ON CONFLICT clause. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index or unique constraint. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. + + """ + + self._post_values_clause = OnConflictDoNothing( + index_elements, index_where + ) + + +insert = public_factory( + Insert, ".dialects.sqlite.insert", ".dialects.sqlite.Insert" +) + + +class OnConflictClause(ClauseElement): + stringify_dialect = "sqlite" + + def __init__(self, index_elements=None, index_where=None): + + if index_elements is not None: + self.constraint_target = None + self.inferred_target_elements = index_elements + self.inferred_target_whereclause = index_where + else: + self.constraint_target = ( + self.inferred_target_elements + ) = self.inferred_target_whereclause = None + + +class OnConflictDoNothing(OnConflictClause): + __visit_name__ = "on_conflict_do_nothing" + + +class OnConflictDoUpdate(OnConflictClause): + __visit_name__ = "on_conflict_do_update" + + def __init__( + self, + index_elements=None, + index_where=None, + set_=None, + where=None, + ): + super(OnConflictDoUpdate, self).__init__( + index_elements=index_elements, + index_where=index_where, + ) + + if isinstance(set_, dict): + if not set_: + raise ValueError("set parameter dictionary must not be empty") + elif isinstance(set_, ColumnCollection): + set_ = dict(set_) + else: + raise ValueError( + "set parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) + self.update_values_to_set = [ + (coercions.expect(roles.DMLColumnRole, key), value) + for key, value in set_.items() + ] + self.update_whereclause = where diff --git a/lib/sqlalchemy/dialects/sqlite/json.py b/lib/sqlalchemy/dialects/sqlite/json.py new file mode 100644 index 0000000..614f954 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/json.py @@ -0,0 +1,84 @@ +from ... import types as sqltypes + + +class JSON(sqltypes.JSON): + """SQLite JSON type. + + SQLite supports JSON as of version 3.9 through its JSON1_ extension. Note + that JSON1_ is a + `loadable extension <https://www.sqlite.org/loadext.html>`_ and as such + may not be available, or may require run-time loading. + + :class:`_sqlite.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a SQLite backend. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. + + The :class:`_sqlite.JSON` type supports persistence of JSON values + as well as the core index operations provided by :class:`_types.JSON` + datatype, by adapting the operations to render the ``JSON_EXTRACT`` + function wrapped in the ``JSON_QUOTE`` function at the database level. + Extracted values are quoted in order to ensure that the results are + always JSON string values. + + + .. versionadded:: 1.3 + + + .. _JSON1: https://www.sqlite.org/json1.html + + """ + + +# Note: these objects currently match exactly those of MySQL, however since +# these are not generalizable to all JSON implementations, remain separately +# implemented for each dialect. +class _FormatTypeMixin(object): + def _format_value(self, value): + raise NotImplementedError() + + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + def _format_value(self, value): + if isinstance(value, int): + value = "$[%s]" % value + else: + value = '$."%s"' % value + return value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return "$%s" % ( + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) + ) diff --git a/lib/sqlalchemy/dialects/sqlite/provision.py b/lib/sqlalchemy/dialects/sqlite/provision.py new file mode 100644 index 0000000..e5b17e8 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/provision.py @@ -0,0 +1,142 @@ +import os +import re + +from ... import exc +from ...engine import url as sa_url +from ...testing.provision import create_db +from ...testing.provision import drop_db +from ...testing.provision import follower_url_from_main +from ...testing.provision import generate_driver_url +from ...testing.provision import log +from ...testing.provision import post_configure_engine +from ...testing.provision import run_reap_dbs +from ...testing.provision import stop_test_class_outside_fixtures +from ...testing.provision import temp_table_keyword_args + + +# TODO: I can't get this to build dynamically with pytest-xdist procs +_drivernames = {"pysqlite", "aiosqlite", "pysqlcipher"} + + +@generate_driver_url.for_db("sqlite") +def generate_driver_url(url, driver, query_str): + if driver == "pysqlcipher" and url.get_driver_name() != "pysqlcipher": + if url.database: + url = url.set(database=url.database + ".enc") + url = url.set(password="test") + url = url.set(drivername="sqlite+%s" % (driver,)) + try: + url.get_dialect() + except exc.NoSuchModuleError: + return None + else: + return url + + +@follower_url_from_main.for_db("sqlite") +def _sqlite_follower_url_from_main(url, ident): + url = sa_url.make_url(url) + + if not url.database or url.database == ":memory:": + return url + else: + + m = re.match(r"(.+?)\.(.+)$", url.database) + name, ext = m.group(1, 2) + drivername = url.get_driver_name() + return sa_url.make_url( + "sqlite+%s:///%s_%s.%s" % (drivername, drivername, ident, ext) + ) + + +@post_configure_engine.for_db("sqlite") +def _sqlite_post_configure_engine(url, engine, follower_ident): + from sqlalchemy import event + + @event.listens_for(engine, "connect") + def connect(dbapi_connection, connection_record): + # use file DBs in all cases, memory acts kind of strangely + # as an attached + if not follower_ident: + # note this test_schema.db gets created for all test runs. + # there's not any dedicated cleanup step for it. it in some + # ways corresponds to the "test.test_schema" schema that's + # expected to be already present, so for now it just stays + # in a given checkout directory. + dbapi_connection.execute( + 'ATTACH DATABASE "%s_test_schema.db" AS test_schema' + % (engine.driver,) + ) + else: + dbapi_connection.execute( + 'ATTACH DATABASE "%s_%s_test_schema.db" AS test_schema' + % (follower_ident, engine.driver) + ) + + +@create_db.for_db("sqlite") +def _sqlite_create_db(cfg, eng, ident): + pass + + +@drop_db.for_db("sqlite") +def _sqlite_drop_db(cfg, eng, ident): + for path in [ + "%s.db" % ident, + "%s_%s_test_schema.db" % (ident, eng.driver), + ]: + if os.path.exists(path): + log.info("deleting SQLite database file: %s" % path) + os.remove(path) + + +@stop_test_class_outside_fixtures.for_db("sqlite") +def stop_test_class_outside_fixtures(config, db, cls): + with db.connect() as conn: + files = [ + row.file + for row in conn.exec_driver_sql("PRAGMA database_list") + if row.file + ] + + if files: + db.dispose() + # some sqlite file tests are not cleaning up well yet, so do this + # just to make things simple for now + for file_ in files: + if file_ and os.path.exists(file_): + os.remove(file_) + + +@temp_table_keyword_args.for_db("sqlite") +def _sqlite_temp_table_keyword_args(cfg, eng): + return {"prefixes": ["TEMPORARY"]} + + +@run_reap_dbs.for_db("sqlite") +def _reap_sqlite_dbs(url, idents): + log.info("db reaper connecting to %r", url) + + log.info("identifiers in file: %s", ", ".join(idents)) + for ident in idents: + # we don't have a config so we can't call _sqlite_drop_db due to the + # decorator + for ext in ("db", "db.enc"): + for path in ( + ["%s.%s" % (ident, ext)] + + [ + "%s_%s.%s" % (drivername, ident, ext) + for drivername in _drivernames + ] + + [ + "%s_test_schema.%s" % (drivername, ext) + for drivername in _drivernames + ] + + [ + "%s_%s_test_schema.%s" % (ident, drivername, ext) + for drivername in _drivernames + ] + ): + if os.path.exists(path): + log.info("deleting SQLite database file: %s" % path) + os.remove(path) diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py new file mode 100644 index 0000000..65f94c8 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -0,0 +1,164 @@ +# sqlite/pysqlcipher.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: sqlite+pysqlcipher + :name: pysqlcipher + :dbapi: sqlcipher 3 or pysqlcipher + :connectstring: sqlite+pysqlcipher://:passphrase@/file_path[?kdf_iter=<iter>] + + Dialect for support of DBAPIs that make use of the + `SQLCipher <https://www.zetetic.net/sqlcipher>`_ backend. + + +Driver +------ + +Current dialect selection logic is: + +* If the :paramref:`_sa.create_engine.module` parameter supplies a DBAPI module, + that module is used. +* Otherwise for Python 3, choose https://pypi.org/project/sqlcipher3/ +* If not available, fall back to https://pypi.org/project/pysqlcipher3/ +* For Python 2, https://pypi.org/project/pysqlcipher/ is used. + +.. warning:: The ``pysqlcipher3`` and ``pysqlcipher`` DBAPI drivers are no + longer maintained; the ``sqlcipher3`` driver as of this writing appears + to be current. For future compatibility, any pysqlcipher-compatible DBAPI + may be used as follows:: + + import sqlcipher_compatible_driver + + from sqlalchemy import create_engine + + e = create_engine( + "sqlite+pysqlcipher://:password@/dbname.db", + module=sqlcipher_compatible_driver + ) + +These drivers make use of the SQLCipher engine. This system essentially +introduces new PRAGMA commands to SQLite which allows the setting of a +passphrase and other encryption parameters, allowing the database file to be +encrypted. + + +Connect Strings +--------------- + +The format of the connect string is in every way the same as that +of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the +"password" field is now accepted, which should contain a passphrase:: + + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db') + +For an absolute file path, two leading slashes should be used for the +database name:: + + e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db') + +A selection of additional encryption-related pragmas supported by SQLCipher +as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed +in the query string, and will result in that PRAGMA being called for each +new connection. Currently, ``cipher``, ``kdf_iter`` +``cipher_page_size`` and ``cipher_use_hmac`` are supported:: + + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000') + +.. warning:: Previous versions of sqlalchemy did not take into consideration + the encryption-related pragmas passed in the url string, that were silently + ignored. This may cause errors when opening files saved by a + previous sqlalchemy version if the encryption options do not match. + + +Pooling Behavior +---------------- + +The driver makes a change to the default pool behavior of pysqlite +as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver +has been observed to be significantly slower on connection than the +pysqlite driver, most likely due to the encryption overhead, so the +dialect here defaults to using the :class:`.SingletonThreadPool` +implementation, +instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool +implementation is entirely configurable using the +:paramref:`_sa.create_engine.poolclass` parameter; the :class:`. +StaticPool` may +be more feasible for single-threaded use, or :class:`.NullPool` may be used +to prevent unencrypted connections from being held open for long periods of +time, at the expense of slower startup time for new connections. + + +""" # noqa + +from __future__ import absolute_import + +from .pysqlite import SQLiteDialect_pysqlite +from ... import pool +from ... import util + + +class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): + driver = "pysqlcipher" + supports_statement_cache = True + + pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac") + + @classmethod + def dbapi(cls): + if util.py3k: + try: + import sqlcipher3 as sqlcipher + except ImportError: + pass + else: + return sqlcipher + + from pysqlcipher3 import dbapi2 as sqlcipher + + else: + from pysqlcipher import dbapi2 as sqlcipher + + return sqlcipher + + @classmethod + def get_pool_class(cls, url): + return pool.SingletonThreadPool + + def on_connect_url(self, url): + super_on_connect = super( + SQLiteDialect_pysqlcipher, self + ).on_connect_url(url) + + # pull the info we need from the URL early. Even though URL + # is immutable, we don't want any in-place changes to the URL + # to affect things + passphrase = url.password or "" + url_query = dict(url.query) + + def on_connect(conn): + cursor = conn.cursor() + cursor.execute('pragma key="%s"' % passphrase) + for prag in self.pragmas: + value = url_query.get(prag, None) + if value is not None: + cursor.execute('pragma %s="%s"' % (prag, value)) + cursor.close() + + if super_on_connect: + super_on_connect(conn) + + return on_connect + + def create_connect_args(self, url): + plain_url = url._replace(password=None) + plain_url = plain_url.difference_update_query(self.pragmas) + return super(SQLiteDialect_pysqlcipher, self).create_connect_args( + plain_url + ) + + +dialect = SQLiteDialect_pysqlcipher diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py new file mode 100644 index 0000000..1aae561 --- /dev/null +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -0,0 +1,613 @@ +# sqlite/pysqlite.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" +.. dialect:: sqlite+pysqlite + :name: pysqlite + :dbapi: sqlite3 + :connectstring: sqlite+pysqlite:///file_path + :url: https://docs.python.org/library/sqlite3.html + + Note that ``pysqlite`` is the same driver as the ``sqlite3`` + module included with the Python distribution. + +Driver +------ + +The ``sqlite3`` Python DBAPI is standard on all modern Python versions; +for cPython and Pypy, no additional installation is necessary. + + +Connect Strings +--------------- + +The file specification for the SQLite database is taken as the "database" +portion of the URL. Note that the format of a SQLAlchemy url is:: + + driver://user:pass@host/database + +This means that the actual filename to be used starts with the characters to +the **right** of the third slash. So connecting to a relative filepath +looks like:: + + # relative path + e = create_engine('sqlite:///path/to/database.db') + +An absolute path, which is denoted by starting with a slash, means you +need **four** slashes:: + + # absolute path + e = create_engine('sqlite:////path/to/database.db') + +To use a Windows path, regular drive specifications and backslashes can be +used. Double backslashes are probably needed:: + + # absolute path on Windows + e = create_engine('sqlite:///C:\\path\\to\\database.db') + +The sqlite ``:memory:`` identifier is the default if no filepath is +present. Specify ``sqlite://`` and nothing else:: + + # in-memory database + e = create_engine('sqlite://') + +.. _pysqlite_uri_connections: + +URI Connections +^^^^^^^^^^^^^^^ + +Modern versions of SQLite support an alternative system of connecting using a +`driver level URI <https://www.sqlite.org/uri.html>`_, which has the advantage +that additional driver-level arguments can be passed including options such as +"read only". The Python sqlite3 driver supports this mode under modern Python +3 versions. The SQLAlchemy pysqlite driver supports this mode of use by +specifying "uri=true" in the URL query string. The SQLite-level "URI" is kept +as the "database" portion of the SQLAlchemy url (that is, following a slash):: + + e = create_engine("sqlite:///file:path/to/database?mode=ro&uri=true") + +.. note:: The "uri=true" parameter must appear in the **query string** + of the URL. It will not currently work as expected if it is only + present in the :paramref:`_sa.create_engine.connect_args` + parameter dictionary. + +The logic reconciles the simultaneous presence of SQLAlchemy's query string and +SQLite's query string by separating out the parameters that belong to the +Python sqlite3 driver vs. those that belong to the SQLite URI. This is +achieved through the use of a fixed list of parameters known to be accepted by +the Python side of the driver. For example, to include a URL that indicates +the Python sqlite3 "timeout" and "check_same_thread" parameters, along with the +SQLite "mode" and "nolock" parameters, they can all be passed together on the +query string:: + + e = create_engine( + "sqlite:///file:path/to/database?" + "check_same_thread=true&timeout=10&mode=ro&nolock=1&uri=true" + ) + +Above, the pysqlite / sqlite3 DBAPI would be passed arguments as:: + + sqlite3.connect( + "file:path/to/database?mode=ro&nolock=1", + check_same_thread=True, timeout=10, uri=True + ) + +Regarding future parameters added to either the Python or native drivers. new +parameter names added to the SQLite URI scheme should be automatically +accommodated by this scheme. New parameter names added to the Python driver +side can be accommodated by specifying them in the +:paramref:`_sa.create_engine.connect_args` dictionary, +until dialect support is +added by SQLAlchemy. For the less likely case that the native SQLite driver +adds a new parameter name that overlaps with one of the existing, known Python +driver parameters (such as "timeout" perhaps), SQLAlchemy's dialect would +require adjustment for the URL scheme to continue to support this. + +As is always the case for all SQLAlchemy dialects, the entire "URL" process +can be bypassed in :func:`_sa.create_engine` through the use of the +:paramref:`_sa.create_engine.creator` +parameter which allows for a custom callable +that creates a Python sqlite3 driver level connection directly. + +.. versionadded:: 1.3.9 + +.. seealso:: + + `Uniform Resource Identifiers <https://www.sqlite.org/uri.html>`_ - in + the SQLite documentation + +.. _pysqlite_regexp: + +Regular Expression Support +--------------------------- + +.. versionadded:: 1.4 + +Support for the :meth:`_sql.ColumnOperators.regexp_match` operator is provided +using Python's re.search_ function. SQLite itself does not include a working +regular expression operator; instead, it includes a non-implemented placeholder +operator ``REGEXP`` that calls a user-defined function that must be provided. + +SQLAlchemy's implementation makes use of the pysqlite create_function_ hook +as follows:: + + + def regexp(a, b): + return re.search(a, b) is not None + + sqlite_connection.create_function( + "regexp", 2, regexp, + ) + +There is currently no support for regular expression flags as a separate +argument, as these are not supported by SQLite's REGEXP operator, however these +may be included inline within the regular expression string. See `Python regular expressions`_ for +details. + +.. seealso:: + + `Python regular expressions`_: Documentation for Python's regular expression syntax. + +.. _create_function: https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function + +.. _re.search: https://docs.python.org/3/library/re.html#re.search + +.. _Python regular expressions: https://docs.python.org/3/library/re.html#re.search + + + +Compatibility with sqlite3 "native" date and datetime types +----------------------------------------------------------- + +The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and +sqlite3.PARSE_COLNAMES options, which have the effect of any column +or expression explicitly cast as "date" or "timestamp" will be converted +to a Python date or datetime object. The date and datetime types provided +with the pysqlite dialect are not currently compatible with these options, +since they render the ISO date/datetime including microseconds, which +pysqlite's driver does not. Additionally, SQLAlchemy does not at +this time automatically render the "cast" syntax required for the +freestanding functions "current_timestamp" and "current_date" to return +datetime/date types natively. Unfortunately, pysqlite +does not provide the standard DBAPI types in ``cursor.description``, +leaving SQLAlchemy with no way to detect these types on the fly +without expensive per-row type checks. + +Keeping in mind that pysqlite's parsing option is not recommended, +nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES +can be forced if one configures "native_datetime=True" on create_engine():: + + engine = create_engine('sqlite://', + connect_args={'detect_types': + sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES}, + native_datetime=True + ) + +With this flag enabled, the DATE and TIMESTAMP types (but note - not the +DATETIME or TIME types...confused yet ?) will not perform any bind parameter +or result processing. Execution of "func.current_date()" will return a string. +"func.current_timestamp()" is registered as returning a DATETIME type in +SQLAlchemy, so this function still receives SQLAlchemy-level result +processing. + +.. _pysqlite_threading_pooling: + +Threading/Pooling Behavior +--------------------------- + +Pysqlite's default behavior is to prohibit the usage of a single connection +in more than one thread. This is originally intended to work with older +versions of SQLite that did not support multithreaded operation under +various circumstances. In particular, older SQLite versions +did not allow a ``:memory:`` database to be used in multiple threads +under any circumstances. + +Pysqlite does include a now-undocumented flag known as +``check_same_thread`` which will disable this check, however note that +pysqlite connections are still not safe to use in concurrently in multiple +threads. In particular, any statement execution calls would need to be +externally mutexed, as Pysqlite does not provide for thread-safe propagation +of error messages among other things. So while even ``:memory:`` databases +can be shared among threads in modern SQLite, Pysqlite doesn't provide enough +thread-safety to make this usage worth it. + +SQLAlchemy sets up pooling to work with Pysqlite's default behavior: + +* When a ``:memory:`` SQLite database is specified, the dialect by default + will use :class:`.SingletonThreadPool`. This pool maintains a single + connection per thread, so that all access to the engine within the current + thread use the same ``:memory:`` database - other threads would access a + different ``:memory:`` database. +* When a file-based database is specified, the dialect will use + :class:`.NullPool` as the source of connections. This pool closes and + discards connections which are returned to the pool immediately. SQLite + file-based connections have extremely low overhead, so pooling is not + necessary. The scheme also prevents a connection from being used again in + a different thread and works best with SQLite's coarse-grained file locking. + +Using a Memory Database in Multiple Threads +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To use a ``:memory:`` database in a multithreaded scenario, the same +connection object must be shared among threads, since the database exists +only within the scope of that connection. The +:class:`.StaticPool` implementation will maintain a single connection +globally, and the ``check_same_thread`` flag can be passed to Pysqlite +as ``False``:: + + from sqlalchemy.pool import StaticPool + engine = create_engine('sqlite://', + connect_args={'check_same_thread':False}, + poolclass=StaticPool) + +Note that using a ``:memory:`` database in multiple threads requires a recent +version of SQLite. + +Using Temporary Tables with SQLite +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Due to the way SQLite deals with temporary tables, if you wish to use a +temporary table in a file-based SQLite database across multiple checkouts +from the connection pool, such as when using an ORM :class:`.Session` where +the temporary table should continue to remain after :meth:`.Session.commit` or +:meth:`.Session.rollback` is called, a pool which maintains a single +connection must be used. Use :class:`.SingletonThreadPool` if the scope is +only needed within the current thread, or :class:`.StaticPool` is scope is +needed within multiple threads for this case:: + + # maintain the same connection per thread + from sqlalchemy.pool import SingletonThreadPool + engine = create_engine('sqlite:///mydb.db', + poolclass=SingletonThreadPool) + + + # maintain the same connection across all threads + from sqlalchemy.pool import StaticPool + engine = create_engine('sqlite:///mydb.db', + poolclass=StaticPool) + +Note that :class:`.SingletonThreadPool` should be configured for the number +of threads that are to be used; beyond that number, connections will be +closed out in a non deterministic way. + +Unicode +------- + +The pysqlite driver only returns Python ``unicode`` objects in result sets, +never plain strings, and accommodates ``unicode`` objects within bound +parameter values in all cases. Regardless of the SQLAlchemy string type in +use, string-based result values will by Python ``unicode`` in Python 2. +The :class:`.Unicode` type should still be used to indicate those columns that +require unicode, however, so that non-``unicode`` values passed inadvertently +will emit a warning. Pysqlite will emit an error if a non-``unicode`` string +is passed containing non-ASCII characters. + +Dealing with Mixed String / Binary Columns in Python 3 +------------------------------------------------------ + +The SQLite database is weakly typed, and as such it is possible when using +binary values, which in Python 3 are represented as ``b'some string'``, that a +particular SQLite database can have data values within different rows where +some of them will be returned as a ``b''`` value by the Pysqlite driver, and +others will be returned as Python strings, e.g. ``''`` values. This situation +is not known to occur if the SQLAlchemy :class:`.LargeBinary` datatype is used +consistently, however if a particular SQLite database has data that was +inserted using the Pysqlite driver directly, or when using the SQLAlchemy +:class:`.String` type which was later changed to :class:`.LargeBinary`, the +table will not be consistently readable because SQLAlchemy's +:class:`.LargeBinary` datatype does not handle strings so it has no way of +"encoding" a value that is in string format. + +To deal with a SQLite table that has mixed string / binary data in the +same column, use a custom type that will check each row individually:: + + # note this is Python 3 only + + from sqlalchemy import String + from sqlalchemy import TypeDecorator + + class MixedBinary(TypeDecorator): + impl = String + cache_ok = True + + def process_result_value(self, value, dialect): + if isinstance(value, str): + value = bytes(value, 'utf-8') + elif value is not None: + value = bytes(value) + + return value + +Then use the above ``MixedBinary`` datatype in the place where +:class:`.LargeBinary` would normally be used. + +.. _pysqlite_serializable: + +Serializable isolation / Savepoints / Transactional DDL +------------------------------------------------------- + +In the section :ref:`sqlite_concurrency`, we refer to the pysqlite +driver's assortment of issues that prevent several features of SQLite +from working correctly. The pysqlite DBAPI driver has several +long-standing bugs which impact the correctness of its transactional +behavior. In its default mode of operation, SQLite features such as +SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are +non-functional, and in order to use these features, workarounds must +be taken. + +The issue is essentially that the driver attempts to second-guess the user's +intent, failing to start transactions and sometimes ending them prematurely, in +an effort to minimize the SQLite databases's file locking behavior, even +though SQLite itself uses "shared" locks for read-only activities. + +SQLAlchemy chooses to not alter this behavior by default, as it is the +long-expected behavior of the pysqlite driver; if and when the pysqlite +driver attempts to repair these issues, that will be more of a driver towards +defaults for SQLAlchemy. + +The good news is that with a few events, we can implement transactional +support fully, by disabling pysqlite's feature entirely and emitting BEGIN +ourselves. This is achieved using two event listeners:: + + from sqlalchemy import create_engine, event + + engine = create_engine("sqlite:///myfile.db") + + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + # disable pysqlite's emitting of the BEGIN statement entirely. + # also stops it from emitting COMMIT before any DDL. + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + # emit our own BEGIN + conn.exec_driver_sql("BEGIN") + +.. warning:: When using the above recipe, it is advised to not use the + :paramref:`.Connection.execution_options.isolation_level` setting on + :class:`_engine.Connection` and :func:`_sa.create_engine` + with the SQLite driver, + as this function necessarily will also alter the ".isolation_level" setting. + + +Above, we intercept a new pysqlite connection and disable any transactional +integration. Then, at the point at which SQLAlchemy knows that transaction +scope is to begin, we emit ``"BEGIN"`` ourselves. + +When we take control of ``"BEGIN"``, we can also control directly SQLite's +locking modes, introduced at +`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_, +by adding the desired locking mode to our ``"BEGIN"``:: + + @event.listens_for(engine, "begin") + def do_begin(conn): + conn.exec_driver_sql("BEGIN EXCLUSIVE") + +.. seealso:: + + `BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_ - + on the SQLite site + + `sqlite3 SELECT does not BEGIN a transaction <https://bugs.python.org/issue9924>`_ - + on the Python bug tracker + + `sqlite3 module breaks transactions and potentially corrupts data <https://bugs.python.org/issue10740>`_ - + on the Python bug tracker + + +""" # noqa + +import os +import re + +from .base import DATE +from .base import DATETIME +from .base import SQLiteDialect +from ... import exc +from ... import pool +from ... import types as sqltypes +from ... import util + + +class _SQLite_pysqliteTimeStamp(DATETIME): + def bind_processor(self, dialect): + if dialect.native_datetime: + return None + else: + return DATETIME.bind_processor(self, dialect) + + def result_processor(self, dialect, coltype): + if dialect.native_datetime: + return None + else: + return DATETIME.result_processor(self, dialect, coltype) + + +class _SQLite_pysqliteDate(DATE): + def bind_processor(self, dialect): + if dialect.native_datetime: + return None + else: + return DATE.bind_processor(self, dialect) + + def result_processor(self, dialect, coltype): + if dialect.native_datetime: + return None + else: + return DATE.result_processor(self, dialect, coltype) + + +class SQLiteDialect_pysqlite(SQLiteDialect): + default_paramstyle = "qmark" + supports_statement_cache = True + + colspecs = util.update_copy( + SQLiteDialect.colspecs, + { + sqltypes.Date: _SQLite_pysqliteDate, + sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp, + }, + ) + + if not util.py2k: + description_encoding = None + + driver = "pysqlite" + + @classmethod + def dbapi(cls): + if util.py2k: + try: + from pysqlite2 import dbapi2 as sqlite + except ImportError: + try: + from sqlite3 import dbapi2 as sqlite + except ImportError as e: + raise e + else: + from sqlite3 import dbapi2 as sqlite + return sqlite + + @classmethod + def _is_url_file_db(cls, url): + if (url.database and url.database != ":memory:") and ( + url.query.get("mode", None) != "memory" + ): + return True + else: + return False + + @classmethod + def get_pool_class(cls, url): + if cls._is_url_file_db(url): + return pool.NullPool + else: + return pool.SingletonThreadPool + + def _get_server_version_info(self, connection): + return self.dbapi.sqlite_version_info + + _isolation_lookup = SQLiteDialect._isolation_lookup.union( + { + "AUTOCOMMIT": None, + } + ) + + def set_isolation_level(self, connection, level): + if hasattr(connection, "dbapi_connection"): + dbapi_connection = connection.dbapi_connection + else: + dbapi_connection = connection + + if level == "AUTOCOMMIT": + dbapi_connection.isolation_level = None + else: + dbapi_connection.isolation_level = "" + return super(SQLiteDialect_pysqlite, self).set_isolation_level( + connection, level + ) + + def on_connect(self): + connect = super(SQLiteDialect_pysqlite, self).on_connect() + + def regexp(a, b): + if b is None: + return None + return re.search(a, b) is not None + + def set_regexp(connection): + if hasattr(connection, "dbapi_connection"): + dbapi_connection = connection.dbapi_connection + else: + dbapi_connection = connection + dbapi_connection.create_function( + "regexp", + 2, + regexp, + ) + + fns = [set_regexp] + + if self.isolation_level is not None: + + def iso_level(conn): + self.set_isolation_level(conn, self.isolation_level) + + fns.append(iso_level) + + def connect(conn): + for fn in fns: + fn(conn) + + return connect + + def create_connect_args(self, url): + if url.username or url.password or url.host or url.port: + raise exc.ArgumentError( + "Invalid SQLite URL: %s\n" + "Valid SQLite URL forms are:\n" + " sqlite:///:memory: (or, sqlite://)\n" + " sqlite:///relative/path/to/file.db\n" + " sqlite:////absolute/path/to/file.db" % (url,) + ) + + # theoretically, this list can be augmented, at least as far as + # parameter names accepted by sqlite3/pysqlite, using + # inspect.getfullargspec(). for the moment this seems like overkill + # as these parameters don't change very often, and as always, + # parameters passed to connect_args will always go to the + # sqlite3/pysqlite driver. + pysqlite_args = [ + ("uri", bool), + ("timeout", float), + ("isolation_level", str), + ("detect_types", int), + ("check_same_thread", bool), + ("cached_statements", int), + ] + opts = url.query + pysqlite_opts = {} + for key, type_ in pysqlite_args: + util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts) + + if pysqlite_opts.get("uri", False): + uri_opts = dict(opts) + # here, we are actually separating the parameters that go to + # sqlite3/pysqlite vs. those that go the SQLite URI. What if + # two names conflict? again, this seems to be not the case right + # now, and in the case that new names are added to + # either side which overlap, again the sqlite3/pysqlite parameters + # can be passed through connect_args instead of in the URL. + # If SQLite native URIs add a parameter like "timeout" that + # we already have listed here for the python driver, then we need + # to adjust for that here. + for key, type_ in pysqlite_args: + uri_opts.pop(key, None) + filename = url.database + if uri_opts: + # sorting of keys is for unit test support + filename += "?" + ( + "&".join( + "%s=%s" % (key, uri_opts[key]) + for key in sorted(uri_opts) + ) + ) + else: + filename = url.database or ":memory:" + if filename != ":memory:": + filename = os.path.abspath(filename) + + return ([filename], pysqlite_opts) + + def is_disconnect(self, e, connection, cursor): + return isinstance( + e, self.dbapi.ProgrammingError + ) and "Cannot operate on a closed database." in str(e) + + +dialect = SQLiteDialect_pysqlite diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py new file mode 100644 index 0000000..c7755c8 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/__init__.py @@ -0,0 +1,67 @@ +# sybase/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from . import base # noqa +from . import pyodbc # noqa +from . import pysybase # noqa +from .base import BIGINT +from .base import BINARY +from .base import BIT +from .base import CHAR +from .base import DATE +from .base import DATETIME +from .base import FLOAT +from .base import IMAGE +from .base import INT +from .base import INTEGER +from .base import MONEY +from .base import NCHAR +from .base import NUMERIC +from .base import NVARCHAR +from .base import SMALLINT +from .base import SMALLMONEY +from .base import TEXT +from .base import TIME +from .base import TINYINT +from .base import UNICHAR +from .base import UNITEXT +from .base import UNIVARCHAR +from .base import VARBINARY +from .base import VARCHAR + + +# default dialect +base.dialect = dialect = pyodbc.dialect + + +__all__ = ( + "CHAR", + "VARCHAR", + "TIME", + "NCHAR", + "NVARCHAR", + "TEXT", + "DATE", + "DATETIME", + "FLOAT", + "NUMERIC", + "BIGINT", + "INT", + "INTEGER", + "SMALLINT", + "BINARY", + "VARBINARY", + "UNITEXT", + "UNICHAR", + "UNIVARCHAR", + "IMAGE", + "BIT", + "MONEY", + "SMALLMONEY", + "TINYINT", + "dialect", +) diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py new file mode 100644 index 0000000..83248d1 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -0,0 +1,1100 @@ +# sybase/base.py +# Copyright (C) 2010-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# get_select_precolumns(), limit_clause() implementation +# copyright (C) 2007 Fisch Asset Management +# AG https://www.fam.ch, with coding by Alexander Houben +# alexander.houben@thor-solutions.ch +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" + +.. dialect:: sybase + :name: Sybase + +.. note:: + + The Sybase dialect within SQLAlchemy **is not currently supported**. + It is not tested within continuous integration and is likely to have + many issues and caveats not currently handled. Consider using the + `external dialect <https://github.com/gordthompson/sqlalchemy-sybase>`_ + instead. + +.. deprecated:: 1.4 The internal Sybase dialect is deprecated and will be + removed in a future version. Use the external dialect. + +""" + +import re + +from sqlalchemy import exc +from sqlalchemy import schema as sa_schema +from sqlalchemy import types as sqltypes +from sqlalchemy import util +from sqlalchemy.engine import default +from sqlalchemy.engine import reflection +from sqlalchemy.sql import compiler +from sqlalchemy.sql import text +from sqlalchemy.types import BIGINT +from sqlalchemy.types import BINARY +from sqlalchemy.types import CHAR +from sqlalchemy.types import DATE +from sqlalchemy.types import DATETIME +from sqlalchemy.types import DECIMAL +from sqlalchemy.types import FLOAT +from sqlalchemy.types import INT # noqa +from sqlalchemy.types import INTEGER +from sqlalchemy.types import NCHAR +from sqlalchemy.types import NUMERIC +from sqlalchemy.types import NVARCHAR +from sqlalchemy.types import REAL +from sqlalchemy.types import SMALLINT +from sqlalchemy.types import TEXT +from sqlalchemy.types import TIME +from sqlalchemy.types import TIMESTAMP +from sqlalchemy.types import Unicode +from sqlalchemy.types import VARBINARY +from sqlalchemy.types import VARCHAR + + +RESERVED_WORDS = set( + [ + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "backup", + "begin", + "between", + "bigint", + "binary", + "bit", + "bottom", + "break", + "by", + "call", + "capability", + "cascade", + "case", + "cast", + "char", + "char_convert", + "character", + "check", + "checkpoint", + "close", + "comment", + "commit", + "connect", + "constraint", + "contains", + "continue", + "convert", + "create", + "cross", + "cube", + "current", + "current_timestamp", + "current_user", + "cursor", + "date", + "dbspace", + "deallocate", + "dec", + "decimal", + "declare", + "default", + "delete", + "deleting", + "desc", + "distinct", + "do", + "double", + "drop", + "dynamic", + "else", + "elseif", + "encrypted", + "end", + "endif", + "escape", + "except", + "exception", + "exec", + "execute", + "existing", + "exists", + "externlogin", + "fetch", + "first", + "float", + "for", + "force", + "foreign", + "forward", + "from", + "full", + "goto", + "grant", + "group", + "having", + "holdlock", + "identified", + "if", + "in", + "index", + "index_lparen", + "inner", + "inout", + "insensitive", + "insert", + "inserting", + "install", + "instead", + "int", + "integer", + "integrated", + "intersect", + "into", + "iq", + "is", + "isolation", + "join", + "key", + "lateral", + "left", + "like", + "lock", + "login", + "long", + "match", + "membership", + "message", + "mode", + "modify", + "natural", + "new", + "no", + "noholdlock", + "not", + "notify", + "null", + "numeric", + "of", + "off", + "on", + "open", + "option", + "options", + "or", + "order", + "others", + "out", + "outer", + "over", + "passthrough", + "precision", + "prepare", + "primary", + "print", + "privileges", + "proc", + "procedure", + "publication", + "raiserror", + "readtext", + "real", + "reference", + "references", + "release", + "remote", + "remove", + "rename", + "reorganize", + "resource", + "restore", + "restrict", + "return", + "revoke", + "right", + "rollback", + "rollup", + "save", + "savepoint", + "scroll", + "select", + "sensitive", + "session", + "set", + "setuser", + "share", + "smallint", + "some", + "sqlcode", + "sqlstate", + "start", + "stop", + "subtrans", + "subtransaction", + "synchronize", + "syntax_error", + "table", + "temporary", + "then", + "time", + "timestamp", + "tinyint", + "to", + "top", + "tran", + "trigger", + "truncate", + "tsequal", + "unbounded", + "union", + "unique", + "unknown", + "unsigned", + "update", + "updating", + "user", + "using", + "validate", + "values", + "varbinary", + "varchar", + "variable", + "varying", + "view", + "wait", + "waitfor", + "when", + "where", + "while", + "window", + "with", + "with_cube", + "with_lparen", + "with_rollup", + "within", + "work", + "writetext", + ] +) + + +class _SybaseUnitypeMixin(object): + """these types appear to return a buffer object.""" + + def result_processor(self, dialect, coltype): + def process(value): + if value is not None: + return str(value) # decode("ucs-2") + else: + return None + + return process + + +class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode): + __visit_name__ = "UNICHAR" + + +class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode): + __visit_name__ = "UNIVARCHAR" + + +class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText): + __visit_name__ = "UNITEXT" + + +class TINYINT(sqltypes.Integer): + __visit_name__ = "TINYINT" + + +class BIT(sqltypes.TypeEngine): + __visit_name__ = "BIT" + + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = "MONEY" + + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = "SMALLMONEY" + + +class UNIQUEIDENTIFIER(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" + + +class IMAGE(sqltypes.LargeBinary): + __visit_name__ = "IMAGE" + + +class SybaseTypeCompiler(compiler.GenericTypeCompiler): + def visit_large_binary(self, type_, **kw): + return self.visit_IMAGE(type_) + + def visit_boolean(self, type_, **kw): + return self.visit_BIT(type_) + + def visit_unicode(self, type_, **kw): + return self.visit_NVARCHAR(type_) + + def visit_UNICHAR(self, type_, **kw): + return "UNICHAR(%d)" % type_.length + + def visit_UNIVARCHAR(self, type_, **kw): + return "UNIVARCHAR(%d)" % type_.length + + def visit_UNITEXT(self, type_, **kw): + return "UNITEXT" + + def visit_TINYINT(self, type_, **kw): + return "TINYINT" + + def visit_IMAGE(self, type_, **kw): + return "IMAGE" + + def visit_BIT(self, type_, **kw): + return "BIT" + + def visit_MONEY(self, type_, **kw): + return "MONEY" + + def visit_SMALLMONEY(self, type_, **kw): + return "SMALLMONEY" + + def visit_UNIQUEIDENTIFIER(self, type_, **kw): + return "UNIQUEIDENTIFIER" + + +ischema_names = { + "bigint": BIGINT, + "int": INTEGER, + "integer": INTEGER, + "smallint": SMALLINT, + "tinyint": TINYINT, + "unsigned bigint": BIGINT, # TODO: unsigned flags + "unsigned int": INTEGER, # TODO: unsigned flags + "unsigned smallint": SMALLINT, # TODO: unsigned flags + "numeric": NUMERIC, + "decimal": DECIMAL, + "dec": DECIMAL, + "float": FLOAT, + "double": NUMERIC, # TODO + "double precision": NUMERIC, # TODO + "real": REAL, + "smallmoney": SMALLMONEY, + "money": MONEY, + "smalldatetime": DATETIME, + "datetime": DATETIME, + "date": DATE, + "time": TIME, + "char": CHAR, + "character": CHAR, + "varchar": VARCHAR, + "character varying": VARCHAR, + "char varying": VARCHAR, + "unichar": UNICHAR, + "unicode character": UNIVARCHAR, + "nchar": NCHAR, + "national char": NCHAR, + "national character": NCHAR, + "nvarchar": NVARCHAR, + "nchar varying": NVARCHAR, + "national char varying": NVARCHAR, + "national character varying": NVARCHAR, + "text": TEXT, + "unitext": UNITEXT, + "binary": BINARY, + "varbinary": VARBINARY, + "image": IMAGE, + "bit": BIT, + # not in documentation for ASE 15.7 + "long varchar": TEXT, # TODO + "timestamp": TIMESTAMP, + "uniqueidentifier": UNIQUEIDENTIFIER, +} + + +class SybaseInspector(reflection.Inspector): + def __init__(self, conn): + reflection.Inspector.__init__(self, conn) + + def get_table_id(self, table_name, schema=None): + """Return the table id from `table_name` and `schema`.""" + + return self.dialect.get_table_id( + self.bind, table_name, schema, info_cache=self.info_cache + ) + + +class SybaseExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + + def set_ddl_autocommit(self, connection, value): + """Must be implemented by subclasses to accommodate DDL executions. + + "connection" is the raw unwrapped DBAPI connection. "value" + is True or False. when True, the connection should be configured + such that a DDL can take place subsequently. when False, + a DDL has taken place and the connection should be resumed + into non-autocommit mode. + + """ + raise NotImplementedError() + + def pre_exec(self): + if self.isinsert: + tbl = self.compiled.statement.table + seq_column = tbl._autoincrement_column + insert_has_sequence = seq_column is not None + + if insert_has_sequence: + self._enable_identity_insert = ( + seq_column.key in self.compiled_parameters[0] + ) + else: + self._enable_identity_insert = False + + if self._enable_identity_insert: + self.cursor.execute( + "SET IDENTITY_INSERT %s ON" + % self.dialect.identifier_preparer.format_table(tbl) + ) + + if self.isddl: + # TODO: to enhance this, we can detect "ddl in tran" on the + # database settings. this error message should be improved to + # include a note about that. + if not self.should_autocommit: + raise exc.InvalidRequestError( + "The Sybase dialect only supports " + "DDL in 'autocommit' mode at this time." + ) + + self.root_connection.engine.logger.info( + "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')" + ) + + self.set_ddl_autocommit( + self.root_connection.connection.connection, True + ) + + def post_exec(self): + if self.isddl: + self.set_ddl_autocommit(self.root_connection, False) + + if self._enable_identity_insert: + self.cursor.execute( + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table( + self.compiled.statement.table + ) + ) + + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT @@identity AS lastrowid") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + + +class SybaseSQLCompiler(compiler.SQLCompiler): + ansi_bind_rules = True + + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + {"doy": "dayofyear", "dow": "weekday", "milliseconds": "millisecond"}, + ) + + def get_from_hint_text(self, table, text): + return text + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + text += " ROWS LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: + text += " ROWS" + text += " OFFSET " + self.process(select._offset_clause, **kw) + return text + + def visit_extract(self, extract, **kw): + field = self.extract_map.get(extract.field, extract.field) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) + + def visit_now_func(self, fn, **kw): + return "GETDATE()" + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" + # which SQLAlchemy doesn't use + return "" + + def order_by_clause(self, select, **kw): + kw["literal_binds"] = True + order_by = self.process(select._order_by_clause, **kw) + + # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + def delete_table_clause(self, delete_stmt, from_table, extra_froms): + """If we have extra froms make sure we render any alias as hint.""" + ashint = False + if extra_froms: + ashint = True + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, ashint=ashint + ) + + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the DELETE .. FROM clause specific to Sybase.""" + kw["asfrom"] = True + return "FROM " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) + + +class SybaseDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler.process( + column.type, type_expression=column + ) + ) + + if column.table is None: + raise exc.CompileError( + "The Sybase dialect requires Table-bound " + "columns in order to generate DDL" + ) + seq_col = column.table._autoincrement_column + + # install a IDENTITY Sequence if we have an implicit IDENTITY column + if seq_col is column: + sequence = ( + isinstance(column.default, sa_schema.Sequence) + and column.default + ) + if sequence: + start, increment = sequence.start or 1, sequence.increment or 1 + else: + start, increment = 1, 1 + if (start, increment) == (1, 1): + colspec += " IDENTITY" + else: + # TODO: need correct syntax for this + colspec += " IDENTITY(%s,%s)" % (start, increment) + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if column.nullable is not None: + if not column.nullable or column.primary_key: + colspec += " NOT NULL" + else: + colspec += " NULL" + + return colspec + + def visit_drop_index(self, drop): + index = drop.element + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(index.table.name), + self._prepared_index_name(drop.element, include_schema=False), + ) + + +class SybaseIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + +class SybaseDialect(default.DefaultDialect): + name = "sybase" + supports_unicode_statements = False + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + supports_statement_cache = True + + supports_native_boolean = False + supports_unicode_binds = False + postfetch_lastrowid = True + + colspecs = {} + ischema_names = ischema_names + + type_compiler = SybaseTypeCompiler + statement_compiler = SybaseSQLCompiler + ddl_compiler = SybaseDDLCompiler + preparer = SybaseIdentifierPreparer + inspector = SybaseInspector + + construct_arguments = [] + + def __init__(self, *args, **kwargs): + util.warn_deprecated( + "The Sybase dialect is deprecated and will be removed " + "in a future version. This dialect is superseded by the external " + "dialect https://github.com/gordthompson/sqlalchemy-sybase.", + version="1.4", + ) + super(SybaseDialect, self).__init__(*args, **kwargs) + + def _get_default_schema_name(self, connection): + return connection.scalar( + text("SELECT user_name() as user_name").columns(username=Unicode) + ) + + def initialize(self, connection): + super(SybaseDialect, self).initialize(connection) + if ( + self.server_version_info is not None + and self.server_version_info < (15,) + ): + self.max_identifier_length = 30 + else: + self.max_identifier_length = 255 + + def get_table_id(self, connection, table_name, schema=None, **kw): + """Fetch the id for schema.table_name. + + Several reflection methods require the table id. The idea for using + this method is that it can be fetched one time and cached for + subsequent calls. + + """ + + table_id = None + if schema is None: + schema = self.default_schema_name + + TABLEID_SQL = text( + """ + SELECT o.id AS id + FROM sysobjects o JOIN sysusers u ON o.uid=u.uid + WHERE u.name = :schema_name + AND o.name = :table_name + AND o.type in ('U', 'V') + """ + ) + + if util.py2k: + if isinstance(schema, unicode): # noqa + schema = schema.encode("ascii") + if isinstance(table_name, unicode): # noqa + table_name = table_name.encode("ascii") + result = connection.execute( + TABLEID_SQL, schema_name=schema, table_name=table_name + ) + table_id = result.scalar() + if table_id is None: + raise exc.NoSuchTableError(table_name) + return table_id + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) + + COLUMN_SQL = text( + """ + SELECT col.name AS name, + t.name AS type, + (col.status & 8) AS nullable, + (col.status & 128) AS autoincrement, + com.text AS 'default', + col.prec AS precision, + col.scale AS scale, + col.length AS length + FROM systypes t, syscolumns col LEFT OUTER JOIN syscomments com ON + col.cdefault = com.id + WHERE col.usertype = t.usertype + AND col.id = :table_id + ORDER BY col.colid + """ + ) + + results = connection.execute(COLUMN_SQL, table_id=table_id) + + columns = [] + for ( + name, + type_, + nullable, + autoincrement, + default_, + precision, + scale, + length, + ) in results: + col_info = self._get_column_info( + name, + type_, + bool(nullable), + bool(autoincrement), + default_, + precision, + scale, + length, + ) + columns.append(col_info) + + return columns + + def _get_column_info( + self, + name, + type_, + nullable, + autoincrement, + default, + precision, + scale, + length, + ): + + coltype = self.ischema_names.get(type_, None) + + kwargs = {} + + if coltype in (NUMERIC, DECIMAL): + args = (precision, scale) + elif coltype == FLOAT: + args = (precision,) + elif coltype in (CHAR, VARCHAR, UNICHAR, UNIVARCHAR, NCHAR, NVARCHAR): + args = (length,) + else: + args = () + + if coltype: + coltype = coltype(*args, **kwargs) + # is this necessary + # if is_array: + # coltype = ARRAY(coltype) + else: + util.warn( + "Did not recognize type '%s' of column '%s'" % (type_, name) + ) + coltype = sqltypes.NULLTYPE + + if default: + default = default.replace("DEFAULT", "").strip() + default = re.sub("^'(.*)'$", lambda m: m.group(1), default) + else: + default = None + + column_info = dict( + name=name, + type=coltype, + nullable=nullable, + default=default, + autoincrement=autoincrement, + ) + return column_info + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) + + table_cache = {} + column_cache = {} + foreign_keys = [] + + table_cache[table_id] = {"name": table_name, "schema": schema} + + COLUMN_SQL = text( + """ + SELECT c.colid AS id, c.name AS name + FROM syscolumns c + WHERE c.id = :table_id + """ + ) + + results = connection.execute(COLUMN_SQL, table_id=table_id) + columns = {} + for col in results: + columns[col["id"]] = col["name"] + column_cache[table_id] = columns + + REFCONSTRAINT_SQL = text( + """ + SELECT o.name AS name, r.reftabid AS reftable_id, + r.keycnt AS 'count', + r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3, + r.fokey4 AS fokey4, r.fokey5 AS fokey5, r.fokey6 AS fokey6, + r.fokey7 AS fokey7, r.fokey1 AS fokey8, r.fokey9 AS fokey9, + r.fokey10 AS fokey10, r.fokey11 AS fokey11, r.fokey12 AS fokey12, + r.fokey13 AS fokey13, r.fokey14 AS fokey14, r.fokey15 AS fokey15, + r.fokey16 AS fokey16, + r.refkey1 AS refkey1, r.refkey2 AS refkey2, r.refkey3 AS refkey3, + r.refkey4 AS refkey4, r.refkey5 AS refkey5, r.refkey6 AS refkey6, + r.refkey7 AS refkey7, r.refkey1 AS refkey8, r.refkey9 AS refkey9, + r.refkey10 AS refkey10, r.refkey11 AS refkey11, + r.refkey12 AS refkey12, r.refkey13 AS refkey13, + r.refkey14 AS refkey14, r.refkey15 AS refkey15, + r.refkey16 AS refkey16 + FROM sysreferences r JOIN sysobjects o on r.tableid = o.id + WHERE r.tableid = :table_id + """ + ) + referential_constraints = connection.execute( + REFCONSTRAINT_SQL, table_id=table_id + ).fetchall() + + REFTABLE_SQL = text( + """ + SELECT o.name AS name, u.name AS 'schema' + FROM sysobjects o JOIN sysusers u ON o.uid = u.uid + WHERE o.id = :table_id + """ + ) + + for r in referential_constraints: + reftable_id = r["reftable_id"] + + if reftable_id not in table_cache: + c = connection.execute(REFTABLE_SQL, table_id=reftable_id) + reftable = c.fetchone() + c.close() + table_info = {"name": reftable["name"], "schema": None} + if ( + schema is not None + or reftable["schema"] != self.default_schema_name + ): + table_info["schema"] = reftable["schema"] + + table_cache[reftable_id] = table_info + results = connection.execute(COLUMN_SQL, table_id=reftable_id) + reftable_columns = {} + for col in results: + reftable_columns[col["id"]] = col["name"] + column_cache[reftable_id] = reftable_columns + + reftable = table_cache[reftable_id] + reftable_columns = column_cache[reftable_id] + + constrained_columns = [] + referred_columns = [] + for i in range(1, r["count"] + 1): + constrained_columns.append(columns[r["fokey%i" % i]]) + referred_columns.append(reftable_columns[r["refkey%i" % i]]) + + fk_info = { + "constrained_columns": constrained_columns, + "referred_schema": reftable["schema"], + "referred_table": reftable["name"], + "referred_columns": referred_columns, + "name": r["name"], + } + + foreign_keys.append(fk_info) + + return foreign_keys + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) + + INDEX_SQL = text( + """ + SELECT object_name(i.id) AS table_name, + i.keycnt AS 'count', + i.name AS name, + (i.status & 0x2) AS 'unique', + index_col(object_name(i.id), i.indid, 1) AS col_1, + index_col(object_name(i.id), i.indid, 2) AS col_2, + index_col(object_name(i.id), i.indid, 3) AS col_3, + index_col(object_name(i.id), i.indid, 4) AS col_4, + index_col(object_name(i.id), i.indid, 5) AS col_5, + index_col(object_name(i.id), i.indid, 6) AS col_6, + index_col(object_name(i.id), i.indid, 7) AS col_7, + index_col(object_name(i.id), i.indid, 8) AS col_8, + index_col(object_name(i.id), i.indid, 9) AS col_9, + index_col(object_name(i.id), i.indid, 10) AS col_10, + index_col(object_name(i.id), i.indid, 11) AS col_11, + index_col(object_name(i.id), i.indid, 12) AS col_12, + index_col(object_name(i.id), i.indid, 13) AS col_13, + index_col(object_name(i.id), i.indid, 14) AS col_14, + index_col(object_name(i.id), i.indid, 15) AS col_15, + index_col(object_name(i.id), i.indid, 16) AS col_16 + FROM sysindexes i, sysobjects o + WHERE o.id = i.id + AND o.id = :table_id + AND (i.status & 2048) = 0 + AND i.indid BETWEEN 1 AND 254 + """ + ) + + results = connection.execute(INDEX_SQL, table_id=table_id) + indexes = [] + for r in results: + column_names = [] + for i in range(1, r["count"]): + column_names.append(r["col_%i" % (i,)]) + index_info = { + "name": r["name"], + "unique": bool(r["unique"]), + "column_names": column_names, + } + indexes.append(index_info) + + return indexes + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) + + PK_SQL = text( + """ + SELECT object_name(i.id) AS table_name, + i.keycnt AS 'count', + i.name AS name, + index_col(object_name(i.id), i.indid, 1) AS pk_1, + index_col(object_name(i.id), i.indid, 2) AS pk_2, + index_col(object_name(i.id), i.indid, 3) AS pk_3, + index_col(object_name(i.id), i.indid, 4) AS pk_4, + index_col(object_name(i.id), i.indid, 5) AS pk_5, + index_col(object_name(i.id), i.indid, 6) AS pk_6, + index_col(object_name(i.id), i.indid, 7) AS pk_7, + index_col(object_name(i.id), i.indid, 8) AS pk_8, + index_col(object_name(i.id), i.indid, 9) AS pk_9, + index_col(object_name(i.id), i.indid, 10) AS pk_10, + index_col(object_name(i.id), i.indid, 11) AS pk_11, + index_col(object_name(i.id), i.indid, 12) AS pk_12, + index_col(object_name(i.id), i.indid, 13) AS pk_13, + index_col(object_name(i.id), i.indid, 14) AS pk_14, + index_col(object_name(i.id), i.indid, 15) AS pk_15, + index_col(object_name(i.id), i.indid, 16) AS pk_16 + FROM sysindexes i, sysobjects o + WHERE o.id = i.id + AND o.id = :table_id + AND (i.status & 2048) = 2048 + AND i.indid BETWEEN 1 AND 254 + """ + ) + + results = connection.execute(PK_SQL, table_id=table_id) + pks = results.fetchone() + results.close() + + constrained_columns = [] + if pks: + for i in range(1, pks["count"] + 1): + constrained_columns.append(pks["pk_%i" % (i,)]) + return { + "constrained_columns": constrained_columns, + "name": pks["name"], + } + else: + return {"constrained_columns": [], "name": None} + + @reflection.cache + def get_schema_names(self, connection, **kw): + + SCHEMA_SQL = text("SELECT u.name AS name FROM sysusers u") + + schemas = connection.execute(SCHEMA_SQL) + + return [s["name"] for s in schemas] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + + TABLE_SQL = text( + """ + SELECT o.name AS name + FROM sysobjects o JOIN sysusers u ON o.uid = u.uid + WHERE u.name = :schema_name + AND o.type = 'U' + """ + ) + + if util.py2k: + if isinstance(schema, unicode): # noqa + schema = schema.encode("ascii") + + tables = connection.execute(TABLE_SQL, schema_name=schema) + + return [t["name"] for t in tables] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + + VIEW_DEF_SQL = text( + """ + SELECT c.text + FROM syscomments c JOIN sysobjects o ON c.id = o.id + WHERE o.name = :view_name + AND o.type = 'V' + """ + ) + + if util.py2k: + if isinstance(view_name, unicode): # noqa + view_name = view_name.encode("ascii") + + view = connection.execute(VIEW_DEF_SQL, view_name=view_name) + + return view.scalar() + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + + VIEW_SQL = text( + """ + SELECT o.name AS name + FROM sysobjects o JOIN sysusers u ON o.uid = u.uid + WHERE u.name = :schema_name + AND o.type = 'V' + """ + ) + + if util.py2k: + if isinstance(schema, unicode): # noqa + schema = schema.encode("ascii") + views = connection.execute(VIEW_SQL, schema_name=schema) + + return [v["name"] for v in views] + + def has_table(self, connection, table_name, schema=None): + self._ensure_has_table_connection(connection) + + try: + self.get_table_id(connection, table_name, schema) + except exc.NoSuchTableError: + return False + else: + return True diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py new file mode 100644 index 0000000..fe5a614 --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py @@ -0,0 +1,34 @@ +# sybase/mxodbc.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +""" + +.. dialect:: sybase+mxodbc + :name: mxODBC + :dbapi: mxodbc + :connectstring: sybase+mxodbc://<username>:<password>@<dsnname> + :url: https://www.egenix.com/ + +.. note:: + + This dialect is a stub only and is likely non functional at this time. + +""" +from sqlalchemy.connectors.mxodbc import MxODBCConnector +from sqlalchemy.dialects.sybase.base import SybaseDialect +from sqlalchemy.dialects.sybase.base import SybaseExecutionContext + + +class SybaseExecutionContext_mxodbc(SybaseExecutionContext): + pass + + +class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect): + execution_ctx_cls = SybaseExecutionContext_mxodbc + supports_statement_cache = True + + +dialect = SybaseDialect_mxodbc diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py new file mode 100644 index 0000000..f408e8f --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -0,0 +1,89 @@ +# sybase/pyodbc.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: sybase+pyodbc + :name: PyODBC + :dbapi: pyodbc + :connectstring: sybase+pyodbc://<username>:<password>@<dsnname>[/<database>] + :url: https://pypi.org/project/pyodbc/ + +Unicode Support +--------------- + +The pyodbc driver currently supports usage of these Sybase types with +Unicode or multibyte strings:: + + CHAR + NCHAR + NVARCHAR + TEXT + VARCHAR + +Currently *not* supported are:: + + UNICHAR + UNITEXT + UNIVARCHAR + +""" # noqa + +import decimal + +from sqlalchemy import processors +from sqlalchemy import types as sqltypes +from sqlalchemy.connectors.pyodbc import PyODBCConnector +from sqlalchemy.dialects.sybase.base import SybaseDialect +from sqlalchemy.dialects.sybase.base import SybaseExecutionContext + + +class _SybNumeric_pyodbc(sqltypes.Numeric): + """Turns Decimals with adjusted() < -6 into floats. + + It's not yet known how to get decimals with many + significant digits or very large adjusted() into Sybase + via pyodbc. + + """ + + def bind_processor(self, dialect): + super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect) + + def process(value): + if self.asdecimal and isinstance(value, decimal.Decimal): + + if value.adjusted() < -6: + return processors.to_float(value) + + if super_process: + return super_process(value) + else: + return value + + return process + + +class SybaseExecutionContext_pyodbc(SybaseExecutionContext): + def set_ddl_autocommit(self, connection, value): + if value: + connection.autocommit = True + else: + connection.autocommit = False + + +class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect): + execution_ctx_cls = SybaseExecutionContext_pyodbc + supports_statement_cache = True + + colspecs = {sqltypes.Numeric: _SybNumeric_pyodbc} + + @classmethod + def dbapi(cls): + return PyODBCConnector.dbapi() + + +dialect = SybaseDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py new file mode 100644 index 0000000..4c96aac --- /dev/null +++ b/lib/sqlalchemy/dialects/sybase/pysybase.py @@ -0,0 +1,106 @@ +# sybase/pysybase.py +# Copyright (C) 2010-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: sybase+pysybase + :name: Python-Sybase + :dbapi: Sybase + :connectstring: sybase+pysybase://<username>:<password>@<dsn>/[database name] + :url: https://python-sybase.sourceforge.net/ + +Unicode Support +--------------- + +The python-sybase driver does not appear to support non-ASCII strings of any +kind at this time. + +""" # noqa + +from sqlalchemy import processors +from sqlalchemy import types as sqltypes +from sqlalchemy.dialects.sybase.base import SybaseDialect +from sqlalchemy.dialects.sybase.base import SybaseExecutionContext +from sqlalchemy.dialects.sybase.base import SybaseSQLCompiler + + +class _SybNumeric(sqltypes.Numeric): + def result_processor(self, dialect, type_): + if not self.asdecimal: + return processors.to_float + else: + return sqltypes.Numeric.result_processor(self, dialect, type_) + + +class SybaseExecutionContext_pysybase(SybaseExecutionContext): + def set_ddl_autocommit(self, dbapi_connection, value): + if value: + # call commit() on the Sybase connection directly, + # to avoid any side effects of calling a Connection + # transactional method inside of pre_exec() + dbapi_connection.commit() + + def pre_exec(self): + SybaseExecutionContext.pre_exec(self) + + for param in self.parameters: + for key in list(param): + param["@" + key] = param[key] + del param[key] + + +class SybaseSQLCompiler_pysybase(SybaseSQLCompiler): + def bindparam_string(self, name, **kw): + return "@" + name + + +class SybaseDialect_pysybase(SybaseDialect): + driver = "pysybase" + execution_ctx_cls = SybaseExecutionContext_pysybase + statement_compiler = SybaseSQLCompiler_pysybase + + supports_statement_cache = True + + colspecs = {sqltypes.Numeric: _SybNumeric, sqltypes.Float: sqltypes.Float} + + @classmethod + def dbapi(cls): + import Sybase + + return Sybase + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user", password="passwd") + + return ([opts.pop("host")], opts) + + def do_executemany(self, cursor, statement, parameters, context=None): + # calling python-sybase executemany yields: + # TypeError: string too long for buffer + for param in parameters: + cursor.execute(statement, param) + + def _get_server_version_info(self, connection): + vers = connection.exec_driver_sql("select @@version_number").scalar() + # i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0), + # (12, 5, 0, 0) + return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10) + + def is_disconnect(self, e, connection, cursor): + if isinstance( + e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError) + ): + msg = str(e) + return ( + "Unable to complete network request to host" in msg + or "Invalid connection state" in msg + or "Invalid cursor state" in msg + ) + else: + return False + + +dialect = SybaseDialect_pysybase |