diff options
Diffstat (limited to 'lib/sqlalchemy/dialects/firebird')
-rw-r--r-- | lib/sqlalchemy/dialects/firebird/__init__.py | 41 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/firebird/base.py | 989 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/firebird/fdb.py | 112 | ||||
-rw-r--r-- | lib/sqlalchemy/dialects/firebird/kinterbasdb.py | 202 |
4 files changed, 1344 insertions, 0 deletions
diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py new file mode 100644 index 0000000..a34eecf --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/__init__.py @@ -0,0 +1,41 @@ +# firebird/__init__.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from sqlalchemy.dialects.firebird.base import BIGINT +from sqlalchemy.dialects.firebird.base import BLOB +from sqlalchemy.dialects.firebird.base import CHAR +from sqlalchemy.dialects.firebird.base import DATE +from sqlalchemy.dialects.firebird.base import FLOAT +from sqlalchemy.dialects.firebird.base import NUMERIC +from sqlalchemy.dialects.firebird.base import SMALLINT +from sqlalchemy.dialects.firebird.base import TEXT +from sqlalchemy.dialects.firebird.base import TIME +from sqlalchemy.dialects.firebird.base import TIMESTAMP +from sqlalchemy.dialects.firebird.base import VARCHAR +from . import base # noqa +from . import fdb # noqa +from . import kinterbasdb # noqa + + +base.dialect = dialect = fdb.dialect + +__all__ = ( + "SMALLINT", + "BIGINT", + "FLOAT", + "FLOAT", + "DATE", + "TIME", + "TEXT", + "NUMERIC", + "FLOAT", + "TIMESTAMP", + "VARCHAR", + "CHAR", + "BLOB", + "dialect", +) diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py new file mode 100644 index 0000000..e2698b1 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -0,0 +1,989 @@ +# firebird/base.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r""" + +.. dialect:: firebird + :name: Firebird + +.. note:: + + The Firebird dialect within SQLAlchemy **is not currently supported**. + It is not tested within continuous integration and is likely to have + many issues and caveats not currently handled. Consider using the + `external dialect <https://github.com/pauldex/sqlalchemy-firebird>`_ + instead. + +.. deprecated:: 1.4 The internal Firebird dialect is deprecated and will be + removed in a future version. Use the external dialect. + +Firebird Dialects +----------------- + +Firebird offers two distinct dialects_ (not to be confused with a +SQLAlchemy ``Dialect``): + +dialect 1 + This is the old syntax and behaviour, inherited from Interbase pre-6.0. + +dialect 3 + This is the newer and supported syntax, introduced in Interbase 6.0. + +The SQLAlchemy Firebird dialect detects these versions and +adjusts its representation of SQL accordingly. However, +support for dialect 1 is not well tested and probably has +incompatibilities. + +Locking Behavior +---------------- + +Firebird locks tables aggressively. For this reason, a DROP TABLE may +hang until other transactions are released. SQLAlchemy does its best +to release transactions as quickly as possible. The most common cause +of hanging transactions is a non-fully consumed result set, i.e.:: + + result = engine.execute(text("select * from table")) + row = result.fetchone() + return + +Where above, the ``CursorResult`` has not been fully consumed. The +connection will be returned to the pool and the transactional state +rolled back once the Python garbage collector reclaims the objects +which hold onto the connection, which often occurs asynchronously. +The above use case can be alleviated by calling ``first()`` on the +``CursorResult`` which will fetch the first row and immediately close +all remaining cursor/connection resources. + +RETURNING support +----------------- + +Firebird 2.0 supports returning a result set from inserts, and 2.1 +extends that to deletes and updates. This is generically exposed by +the SQLAlchemy ``returning()`` method, such as:: + + # INSERT..RETURNING + result = table.insert().returning(table.c.col1, table.c.col2).\ + values(name='foo') + print(result.fetchall()) + + # UPDATE..RETURNING + raises = empl.update().returning(empl.c.id, empl.c.salary).\ + where(empl.c.sales>100).\ + values(dict(salary=empl.c.salary * 1.1)) + print(raises.fetchall()) + + +.. _dialects: https://mc-computing.com/Databases/Firebird/SQL_Dialect.html +""" + +import datetime + +from sqlalchemy import exc +from sqlalchemy import sql +from sqlalchemy import types as sqltypes +from sqlalchemy import util +from sqlalchemy.engine import default +from sqlalchemy.engine import reflection +from sqlalchemy.sql import compiler +from sqlalchemy.sql import expression +from sqlalchemy.types import BIGINT +from sqlalchemy.types import BLOB +from sqlalchemy.types import DATE +from sqlalchemy.types import FLOAT +from sqlalchemy.types import INTEGER +from sqlalchemy.types import Integer +from sqlalchemy.types import NUMERIC +from sqlalchemy.types import SMALLINT +from sqlalchemy.types import TEXT +from sqlalchemy.types import TIME +from sqlalchemy.types import TIMESTAMP + + +RESERVED_WORDS = set( + [ + "active", + "add", + "admin", + "after", + "all", + "alter", + "and", + "any", + "as", + "asc", + "ascending", + "at", + "auto", + "avg", + "before", + "begin", + "between", + "bigint", + "bit_length", + "blob", + "both", + "by", + "case", + "cast", + "char", + "character", + "character_length", + "char_length", + "check", + "close", + "collate", + "column", + "commit", + "committed", + "computed", + "conditional", + "connect", + "constraint", + "containing", + "count", + "create", + "cross", + "cstring", + "current", + "current_connection", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_transaction", + "current_user", + "cursor", + "database", + "date", + "day", + "dec", + "decimal", + "declare", + "default", + "delete", + "desc", + "descending", + "disconnect", + "distinct", + "do", + "domain", + "double", + "drop", + "else", + "end", + "entry_point", + "escape", + "exception", + "execute", + "exists", + "exit", + "external", + "extract", + "fetch", + "file", + "filter", + "float", + "for", + "foreign", + "from", + "full", + "function", + "gdscode", + "generator", + "gen_id", + "global", + "grant", + "group", + "having", + "hour", + "if", + "in", + "inactive", + "index", + "inner", + "input_type", + "insensitive", + "insert", + "int", + "integer", + "into", + "is", + "isolation", + "join", + "key", + "leading", + "left", + "length", + "level", + "like", + "long", + "lower", + "manual", + "max", + "maximum_segment", + "merge", + "min", + "minute", + "module_name", + "month", + "names", + "national", + "natural", + "nchar", + "no", + "not", + "null", + "numeric", + "octet_length", + "of", + "on", + "only", + "open", + "option", + "or", + "order", + "outer", + "output_type", + "overflow", + "page", + "pages", + "page_size", + "parameter", + "password", + "plan", + "position", + "post_event", + "precision", + "primary", + "privileges", + "procedure", + "protected", + "rdb$db_key", + "read", + "real", + "record_version", + "recreate", + "recursive", + "references", + "release", + "reserv", + "reserving", + "retain", + "returning_values", + "returns", + "revoke", + "right", + "rollback", + "rows", + "row_count", + "savepoint", + "schema", + "second", + "segment", + "select", + "sensitive", + "set", + "shadow", + "shared", + "singular", + "size", + "smallint", + "snapshot", + "some", + "sort", + "sqlcode", + "stability", + "start", + "starting", + "starts", + "statistics", + "sub_type", + "sum", + "suspend", + "table", + "then", + "time", + "timestamp", + "to", + "trailing", + "transaction", + "trigger", + "trim", + "uncommitted", + "union", + "unique", + "update", + "upper", + "user", + "using", + "value", + "values", + "varchar", + "variable", + "varying", + "view", + "wait", + "when", + "where", + "while", + "with", + "work", + "write", + "year", + ] +) + + +class _StringType(sqltypes.String): + """Base for Firebird string types.""" + + def __init__(self, charset=None, **kw): + self.charset = charset + super(_StringType, self).__init__(**kw) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """Firebird VARCHAR type""" + + __visit_name__ = "VARCHAR" + + def __init__(self, length=None, **kwargs): + super(VARCHAR, self).__init__(length=length, **kwargs) + + +class CHAR(_StringType, sqltypes.CHAR): + """Firebird CHAR type""" + + __visit_name__ = "CHAR" + + def __init__(self, length=None, **kwargs): + super(CHAR, self).__init__(length=length, **kwargs) + + +class _FBDateTime(sqltypes.DateTime): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + + return process + + +colspecs = {sqltypes.DateTime: _FBDateTime} + +ischema_names = { + "SHORT": SMALLINT, + "LONG": INTEGER, + "QUAD": FLOAT, + "FLOAT": FLOAT, + "DATE": DATE, + "TIME": TIME, + "TEXT": TEXT, + "INT64": BIGINT, + "DOUBLE": FLOAT, + "TIMESTAMP": TIMESTAMP, + "VARYING": VARCHAR, + "CSTRING": CHAR, + "BLOB": BLOB, +} + + +# TODO: date conversion types (should be implemented as _FBDateTime, +# _FBDate, etc. as bind/result functionality is required) + + +class FBTypeCompiler(compiler.GenericTypeCompiler): + def visit_boolean(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) + + def visit_datetime(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) + + def visit_TEXT(self, type_, **kw): + return "BLOB SUB_TYPE 1" + + def visit_BLOB(self, type_, **kw): + return "BLOB SUB_TYPE 0" + + def _extend_string(self, type_, basic): + charset = getattr(type_, "charset", None) + if charset is None: + return basic + else: + return "%s CHARACTER SET %s" % (basic, charset) + + def visit_CHAR(self, type_, **kw): + basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw) + return self._extend_string(type_, basic) + + def visit_VARCHAR(self, type_, **kw): + if not type_.length: + raise exc.CompileError( + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) + basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw) + return self._extend_string(type_, basic) + + +class FBCompiler(sql.compiler.SQLCompiler): + """Firebird specific idiosyncrasies""" + + ansi_bind_rules = True + + # def visit_contains_op_binary(self, binary, operator, **kw): + # cant use CONTAINING b.c. it's case insensitive. + + # def visit_not_contains_op_binary(self, binary, operator, **kw): + # cant use NOT CONTAINING b.c. it's case insensitive. + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_startswith_op_binary(self, binary, operator, **kw): + return "%s STARTING WITH %s" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + + def visit_not_startswith_op_binary(self, binary, operator, **kw): + return "%s NOT STARTING WITH %s" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + + def visit_mod_binary(self, binary, operator, **kw): + return "mod(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_alias(self, alias, asfrom=False, **kwargs): + if self.dialect._version_two: + return super(FBCompiler, self).visit_alias( + alias, asfrom=asfrom, **kwargs + ) + else: + # Override to not use the AS keyword which FB 1.5 does not like + if asfrom: + alias_name = ( + isinstance(alias.name, expression._truncated_label) + and self._truncated_identifier("alias", alias.name) + or alias.name + ) + + return ( + self.process(alias.element, asfrom=asfrom, **kwargs) + + " " + + self.preparer.format_alias(alias, alias_name) + ) + else: + return self.process(alias.element, **kwargs) + + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0]) + start = self.process(func.clauses.clauses[1]) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2]) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) + else: + return "SUBSTRING(%s FROM %s)" % (s, start) + + def visit_length_func(self, function, **kw): + if self.dialect._version_two: + return "char_length" + self.function_argspec(function) + else: + return "strlen" + self.function_argspec(function) + + visit_char_length_func = visit_length_func + + def function_argspec(self, func, **kw): + # TODO: this probably will need to be + # narrowed to a fixed list, some no-arg functions + # may require parens - see similar example in the oracle + # dialect + if func.clauses is not None and len(func.clauses): + return self.process(func.clause_expr, **kw) + else: + return "" + + def default_from(self): + return " FROM rdb$database" + + def visit_sequence(self, seq, **kw): + return "gen_id(%s, 1)" % self.preparer.format_sequence(seq) + + def get_select_precolumns(self, select, **kw): + """Called when building a ``SELECT`` statement, position is just + before column list Firebird puts the limit and offset right + after the ``SELECT``... + """ + + result = "" + if select._limit_clause is not None: + result += "FIRST %s " % self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + result += "SKIP %s " % self.process(select._offset_clause, **kw) + result += super(FBCompiler, self).get_select_precolumns(select, **kw) + return result + + def limit_clause(self, select, **kw): + """Already taken care of in the `get_select_precolumns` method.""" + + return "" + + def returning_clause(self, stmt, returning_cols): + columns = [ + self._label_returning_column(stmt, c) + for c in expression._select_iterables(returning_cols) + ] + + return "RETURNING " + ", ".join(columns) + + +class FBDDLCompiler(sql.compiler.DDLCompiler): + """Firebird syntactic idiosyncrasies""" + + def visit_create_sequence(self, create): + """Generate a ``CREATE GENERATOR`` statement for the sequence.""" + + # no syntax for these + # https://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html + if create.element.start is not None: + raise NotImplementedError( + "Firebird SEQUENCE doesn't support START WITH" + ) + if create.element.increment is not None: + raise NotImplementedError( + "Firebird SEQUENCE doesn't support INCREMENT BY" + ) + + if self.dialect._version_two: + return "CREATE SEQUENCE %s" % self.preparer.format_sequence( + create.element + ) + else: + return "CREATE GENERATOR %s" % self.preparer.format_sequence( + create.element + ) + + def visit_drop_sequence(self, drop): + """Generate a ``DROP GENERATOR`` statement for the sequence.""" + + if self.dialect._version_two: + return "DROP SEQUENCE %s" % self.preparer.format_sequence( + drop.element + ) + else: + return "DROP GENERATOR %s" % self.preparer.format_sequence( + drop.element + ) + + def visit_computed_column(self, generated): + if generated.persisted is not None: + raise exc.CompileError( + "Firebird computed columns do not support a persistence " + "method setting; set the 'persisted' flag to None for " + "Firebird support." + ) + return "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + + +class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): + """Install Firebird specific reserved words.""" + + reserved_words = RESERVED_WORDS + illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union( + ["_"] + ) + + def __init__(self, dialect): + super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) + + +class FBExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq, type_): + """Get the next value from the sequence using ``gen_id()``.""" + + return self._execute_scalar( + "SELECT gen_id(%s, 1) FROM rdb$database" + % self.identifier_preparer.format_sequence(seq), + type_, + ) + + +class FBDialect(default.DefaultDialect): + """Firebird dialect""" + + name = "firebird" + supports_statement_cache = True + + max_identifier_length = 31 + + supports_sequences = True + sequences_optional = False + supports_default_values = True + postfetch_lastrowid = False + + supports_native_boolean = False + + requires_name_normalize = True + supports_empty_insert = False + + statement_compiler = FBCompiler + ddl_compiler = FBDDLCompiler + preparer = FBIdentifierPreparer + type_compiler = FBTypeCompiler + execution_ctx_cls = FBExecutionContext + + colspecs = colspecs + ischema_names = ischema_names + + construct_arguments = [] + + # defaults to dialect ver. 3, + # will be autodetected off upon + # first connect + _version_two = True + + def __init__(self, *args, **kwargs): + util.warn_deprecated( + "The firebird dialect is deprecated and will be removed " + "in a future version. This dialect is superseded by the external " + "dialect https://github.com/pauldex/sqlalchemy-firebird.", + version="1.4", + ) + super(FBDialect, self).__init__(*args, **kwargs) + + def initialize(self, connection): + super(FBDialect, self).initialize(connection) + self._version_two = ( + "firebird" in self.server_version_info + and self.server_version_info >= (2,) + ) or ( + "interbase" in self.server_version_info + and self.server_version_info >= (6,) + ) + + if not self._version_two: + # TODO: whatever other pre < 2.0 stuff goes here + self.ischema_names = ischema_names.copy() + self.ischema_names["TIMESTAMP"] = sqltypes.DATE + self.colspecs = {sqltypes.DateTime: sqltypes.DATE} + + self.implicit_returning = self._version_two and self.__dict__.get( + "implicit_returning", True + ) + + def has_table(self, connection, table_name, schema=None): + """Return ``True`` if the given table exists, ignoring + the `schema`.""" + self._ensure_has_table_connection(connection) + + tblqry = """ + SELECT 1 AS has_table FROM rdb$database + WHERE EXISTS (SELECT rdb$relation_name + FROM rdb$relations + WHERE rdb$relation_name=?) + """ + c = connection.exec_driver_sql( + tblqry, [self.denormalize_name(table_name)] + ) + return c.first() is not None + + def has_sequence(self, connection, sequence_name, schema=None): + """Return ``True`` if the given sequence (generator) exists.""" + + genqry = """ + SELECT 1 AS has_sequence FROM rdb$database + WHERE EXISTS (SELECT rdb$generator_name + FROM rdb$generators + WHERE rdb$generator_name=?) + """ + c = connection.exec_driver_sql( + genqry, [self.denormalize_name(sequence_name)] + ) + return c.first() is not None + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + # there are two queries commonly mentioned for this. + # this one, using view_blr, is at the Firebird FAQ among other places: + # https://www.firebirdfaq.org/faq174/ + s = """ + select rdb$relation_name + from rdb$relations + where rdb$view_blr is null + and (rdb$system_flag is null or rdb$system_flag = 0); + """ + + # the other query is this one. It's not clear if there's really + # any difference between these two. This link: + # https://www.alberton.info/firebird_sql_meta_info.html#.Ur3vXfZGni8 + # states them as interchangeable. Some discussion at [ticket:2898] + # SELECT DISTINCT rdb$relation_name + # FROM rdb$relation_fields + # WHERE rdb$system_flag=0 AND rdb$view_context IS NULL + + return [ + self.normalize_name(row[0]) + for row in connection.exec_driver_sql(s) + ] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + # see https://www.firebirdfaq.org/faq174/ + s = """ + select rdb$relation_name + from rdb$relations + where rdb$view_blr is not null + and (rdb$system_flag is null or rdb$system_flag = 0); + """ + return [ + self.normalize_name(row[0]) + for row in connection.exec_driver_sql(s) + ] + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + qry = """ + SELECT rdb$view_source AS view_source + FROM rdb$relations + WHERE rdb$relation_name=? + """ + rp = connection.exec_driver_sql( + qry, [self.denormalize_name(view_name)] + ) + row = rp.first() + if row: + return row["view_source"] + else: + return None + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + # Query to extract the PK/FK constrained fields of the given table + keyqry = """ + SELECT se.rdb$field_name AS fname + FROM rdb$relation_constraints rc + JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + """ + tablename = self.denormalize_name(table_name) + # get primary key fields + c = connection.exec_driver_sql(keyqry, ["PRIMARY KEY", tablename]) + pkfields = [self.normalize_name(r["fname"]) for r in c.fetchall()] + return {"constrained_columns": pkfields, "name": None} + + @reflection.cache + def get_column_sequence( + self, connection, table_name, column_name, schema=None, **kw + ): + tablename = self.denormalize_name(table_name) + colname = self.denormalize_name(column_name) + # Heuristic-query to determine the generator associated to a PK field + genqry = """ + SELECT trigdep.rdb$depended_on_name AS fgenerator + FROM rdb$dependencies tabdep + JOIN rdb$dependencies trigdep + ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name + AND trigdep.rdb$depended_on_type=14 + AND trigdep.rdb$dependent_type=2 + JOIN rdb$triggers trig ON + trig.rdb$trigger_name=tabdep.rdb$dependent_name + WHERE tabdep.rdb$depended_on_name=? + AND tabdep.rdb$depended_on_type=0 + AND trig.rdb$trigger_type=1 + AND tabdep.rdb$field_name=? + AND (SELECT count(*) + FROM rdb$dependencies trigdep2 + WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2 + """ + genr = connection.exec_driver_sql(genqry, [tablename, colname]).first() + if genr is not None: + return dict(name=self.normalize_name(genr["fgenerator"])) + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + # Query to extract the details of all the fields of the given table + tblqry = """ + SELECT r.rdb$field_name AS fname, + r.rdb$null_flag AS null_flag, + t.rdb$type_name AS ftype, + f.rdb$field_sub_type AS stype, + f.rdb$field_length/ + COALESCE(cs.rdb$bytes_per_character,1) AS flen, + f.rdb$field_precision AS fprec, + f.rdb$field_scale AS fscale, + COALESCE(r.rdb$default_source, + f.rdb$default_source) AS fdefault + FROM rdb$relation_fields r + JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name + JOIN rdb$types t + ON t.rdb$type=f.rdb$field_type AND + t.rdb$field_name='RDB$FIELD_TYPE' + LEFT JOIN rdb$character_sets cs ON + f.rdb$character_set_id=cs.rdb$character_set_id + WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=? + ORDER BY r.rdb$field_position + """ + # get the PK, used to determine the eventual associated sequence + pk_constraint = self.get_pk_constraint(connection, table_name) + pkey_cols = pk_constraint["constrained_columns"] + + tablename = self.denormalize_name(table_name) + # get all of the fields for this table + c = connection.exec_driver_sql(tblqry, [tablename]) + cols = [] + while True: + row = c.fetchone() + if row is None: + break + name = self.normalize_name(row["fname"]) + orig_colname = row["fname"] + + # get the data type + colspec = row["ftype"].rstrip() + coltype = self.ischema_names.get(colspec) + if coltype is None: + util.warn( + "Did not recognize type '%s' of column '%s'" + % (colspec, name) + ) + coltype = sqltypes.NULLTYPE + elif issubclass(coltype, Integer) and row["fprec"] != 0: + coltype = NUMERIC( + precision=row["fprec"], scale=row["fscale"] * -1 + ) + elif colspec in ("VARYING", "CSTRING"): + coltype = coltype(row["flen"]) + elif colspec == "TEXT": + coltype = TEXT(row["flen"]) + elif colspec == "BLOB": + if row["stype"] == 1: + coltype = TEXT() + else: + coltype = BLOB() + else: + coltype = coltype() + + # does it have a default value? + defvalue = None + if row["fdefault"] is not None: + # the value comes down as "DEFAULT 'value'": there may be + # more than one whitespace around the "DEFAULT" keyword + # and it may also be lower case + # (see also https://tracker.firebirdsql.org/browse/CORE-356) + defexpr = row["fdefault"].lstrip() + assert defexpr[:8].rstrip().upper() == "DEFAULT", ( + "Unrecognized default value: %s" % defexpr + ) + defvalue = defexpr[8:].strip() + if defvalue == "NULL": + # Redundant + defvalue = None + col_d = { + "name": name, + "type": coltype, + "nullable": not bool(row["null_flag"]), + "default": defvalue, + "autoincrement": "auto", + } + + if orig_colname.lower() == orig_colname: + col_d["quote"] = True + + # if the PK is a single field, try to see if its linked to + # a sequence thru a trigger + if len(pkey_cols) == 1 and name == pkey_cols[0]: + seq_d = self.get_column_sequence(connection, tablename, name) + if seq_d is not None: + col_d["sequence"] = seq_d + + cols.append(col_d) + return cols + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # Query to extract the details of each UK/FK of the given table + fkqry = """ + SELECT rc.rdb$constraint_name AS cname, + cse.rdb$field_name AS fname, + ix2.rdb$relation_name AS targetrname, + se.rdb$field_name AS targetfname + FROM rdb$relation_constraints rc + JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name + JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key + JOIN rdb$index_segments cse ON + cse.rdb$index_name=ix1.rdb$index_name + JOIN rdb$index_segments se + ON se.rdb$index_name=ix2.rdb$index_name + AND se.rdb$field_position=cse.rdb$field_position + WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=? + ORDER BY se.rdb$index_name, se.rdb$field_position + """ + tablename = self.denormalize_name(table_name) + + c = connection.exec_driver_sql(fkqry, ["FOREIGN KEY", tablename]) + fks = util.defaultdict( + lambda: { + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + } + ) + + for row in c: + cname = self.normalize_name(row["cname"]) + fk = fks[cname] + if not fk["name"]: + fk["name"] = cname + fk["referred_table"] = self.normalize_name(row["targetrname"]) + fk["constrained_columns"].append(self.normalize_name(row["fname"])) + fk["referred_columns"].append( + self.normalize_name(row["targetfname"]) + ) + return list(fks.values()) + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + qry = """ + SELECT ix.rdb$index_name AS index_name, + ix.rdb$unique_flag AS unique_flag, + ic.rdb$field_name AS field_name + FROM rdb$indices ix + JOIN rdb$index_segments ic + ON ix.rdb$index_name=ic.rdb$index_name + LEFT OUTER JOIN rdb$relation_constraints + ON rdb$relation_constraints.rdb$index_name = + ic.rdb$index_name + WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL + AND rdb$relation_constraints.rdb$constraint_type IS NULL + ORDER BY index_name, ic.rdb$field_position + """ + c = connection.exec_driver_sql( + qry, [self.denormalize_name(table_name)] + ) + + indexes = util.defaultdict(dict) + for row in c: + indexrec = indexes[row["index_name"]] + if "name" not in indexrec: + indexrec["name"] = self.normalize_name(row["index_name"]) + indexrec["column_names"] = [] + indexrec["unique"] = bool(row["unique_flag"]) + + indexrec["column_names"].append( + self.normalize_name(row["field_name"]) + ) + + return list(indexes.values()) diff --git a/lib/sqlalchemy/dialects/firebird/fdb.py b/lib/sqlalchemy/dialects/firebird/fdb.py new file mode 100644 index 0000000..38f4432 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/fdb.py @@ -0,0 +1,112 @@ +# firebird/fdb.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: firebird+fdb + :name: fdb + :dbapi: pyodbc + :connectstring: firebird+fdb://user:password@host:port/path/to/db[?key=value&key=value...] + :url: https://pypi.org/project/fdb/ + + fdb is a kinterbasdb compatible DBAPI for Firebird. + + .. versionchanged:: 0.9 - The fdb dialect is now the default dialect + under the ``firebird://`` URL space, as ``fdb`` is now the official + Python driver for Firebird. + +Arguments +---------- + +The ``fdb`` dialect is based on the +:mod:`sqlalchemy.dialects.firebird.kinterbasdb` dialect, however does not +accept every argument that Kinterbasdb does. + +* ``enable_rowcount`` - True by default, setting this to False disables + the usage of "cursor.rowcount" with the + Kinterbasdb dialect, which SQLAlchemy ordinarily calls upon automatically + after any UPDATE or DELETE statement. When disabled, SQLAlchemy's + CursorResult will return -1 for result.rowcount. The rationale here is + that Kinterbasdb requires a second round trip to the database when + .rowcount is called - since SQLA's resultproxy automatically closes + the cursor after a non-result-returning statement, rowcount must be + called, if at all, before the result object is returned. Additionally, + cursor.rowcount may not return correct results with older versions + of Firebird, and setting this flag to False will also cause the + SQLAlchemy ORM to ignore its usage. The behavior can also be controlled on a + per-execution basis using the ``enable_rowcount`` option with + :meth:`_engine.Connection.execution_options`:: + + conn = engine.connect().execution_options(enable_rowcount=True) + r = conn.execute(stmt) + print(r.rowcount) + +* ``retaining`` - False by default. Setting this to True will pass the + ``retaining=True`` keyword argument to the ``.commit()`` and ``.rollback()`` + methods of the DBAPI connection, which can improve performance in some + situations, but apparently with significant caveats. + Please read the fdb and/or kinterbasdb DBAPI documentation in order to + understand the implications of this flag. + + .. versionchanged:: 0.9.0 - the ``retaining`` flag defaults to ``False``. + In 0.8 it defaulted to ``True``. + + .. seealso:: + + https://pythonhosted.org/fdb/usage-guide.html#retaining-transactions + - information on the "retaining" flag. + +""" # noqa + +from .kinterbasdb import FBDialect_kinterbasdb +from ... import util + + +class FBDialect_fdb(FBDialect_kinterbasdb): + supports_statement_cache = True + + def __init__(self, enable_rowcount=True, retaining=False, **kwargs): + super(FBDialect_fdb, self).__init__( + enable_rowcount=enable_rowcount, retaining=retaining, **kwargs + ) + + @classmethod + def dbapi(cls): + return __import__("fdb") + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + if opts.get("port"): + opts["host"] = "%s/%s" % (opts["host"], opts["port"]) + del opts["port"] + opts.update(url.query) + + util.coerce_kw_type(opts, "type_conv", int) + + return ([], opts) + + def _get_server_version_info(self, connection): + """Get the version of the Firebird server used by a connection. + + Returns a tuple of (`major`, `minor`, `build`), three integers + representing the version of the attached server. + """ + + # This is the simpler approach (the other uses the services api), + # that for backward compatibility reasons returns a string like + # LI-V6.3.3.12981 Firebird 2.0 + # where the first version is a fake one resembling the old + # Interbase signature. + + isc_info_firebird_version = 103 + fbconn = connection.connection + + version = fbconn.db_info(isc_info_firebird_version) + + return self._parse_version_info(version) + + +dialect = FBDialect_fdb diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py new file mode 100644 index 0000000..b999404 --- /dev/null +++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py @@ -0,0 +1,202 @@ +# firebird/kinterbasdb.py +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +.. dialect:: firebird+kinterbasdb + :name: kinterbasdb + :dbapi: kinterbasdb + :connectstring: firebird+kinterbasdb://user:password@host:port/path/to/db[?key=value&key=value...] + :url: https://firebirdsql.org/index.php?op=devel&sub=python + +Arguments +---------- + +The Kinterbasdb backend accepts the ``enable_rowcount`` and ``retaining`` +arguments accepted by the :mod:`sqlalchemy.dialects.firebird.fdb` dialect. +In addition, it also accepts the following: + +* ``type_conv`` - select the kind of mapping done on the types: by default + SQLAlchemy uses 200 with Unicode, datetime and decimal support. See + the linked documents below for further information. + +* ``concurrency_level`` - set the backend policy with regards to threading + issues: by default SQLAlchemy uses policy 1. See the linked documents + below for further information. + +.. seealso:: + + https://sourceforge.net/projects/kinterbasdb + + https://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation + + https://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency + +""" # noqa + +import decimal +from re import match + +from .base import FBDialect +from .base import FBExecutionContext +from ... import types as sqltypes +from ... import util + + +class _kinterbasdb_numeric(object): + def bind_processor(self, dialect): + def process(value): + if isinstance(value, decimal.Decimal): + return str(value) + else: + return value + + return process + + +class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric): + pass + + +class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float): + pass + + +class FBExecutionContext_kinterbasdb(FBExecutionContext): + @property + def rowcount(self): + if self.execution_options.get( + "enable_rowcount", self.dialect.enable_rowcount + ): + return self.cursor.rowcount + else: + return -1 + + +class FBDialect_kinterbasdb(FBDialect): + driver = "kinterbasdb" + supports_statement_cache = True + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + execution_ctx_cls = FBExecutionContext_kinterbasdb + + supports_native_decimal = True + + colspecs = util.update_copy( + FBDialect.colspecs, + { + sqltypes.Numeric: _FBNumeric_kinterbasdb, + sqltypes.Float: _FBFloat_kinterbasdb, + }, + ) + + def __init__( + self, + type_conv=200, + concurrency_level=1, + enable_rowcount=True, + retaining=False, + **kwargs + ): + super(FBDialect_kinterbasdb, self).__init__(**kwargs) + self.enable_rowcount = enable_rowcount + self.type_conv = type_conv + self.concurrency_level = concurrency_level + self.retaining = retaining + if enable_rowcount: + self.supports_sane_rowcount = True + + @classmethod + def dbapi(cls): + return __import__("kinterbasdb") + + def do_execute(self, cursor, statement, parameters, context=None): + # kinterbase does not accept a None, but wants an empty list + # when there are no arguments. + cursor.execute(statement, parameters or []) + + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback(self.retaining) + + def do_commit(self, dbapi_connection): + dbapi_connection.commit(self.retaining) + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + if opts.get("port"): + opts["host"] = "%s/%s" % (opts["host"], opts["port"]) + del opts["port"] + opts.update(url.query) + + util.coerce_kw_type(opts, "type_conv", int) + + type_conv = opts.pop("type_conv", self.type_conv) + concurrency_level = opts.pop( + "concurrency_level", self.concurrency_level + ) + + if self.dbapi is not None: + initialized = getattr(self.dbapi, "initialized", None) + if initialized is None: + # CVS rev 1.96 changed the name of the attribute: + # https://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/ + # Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96 + initialized = getattr(self.dbapi, "_initialized", False) + if not initialized: + self.dbapi.init( + type_conv=type_conv, concurrency_level=concurrency_level + ) + return ([], opts) + + def _get_server_version_info(self, connection): + """Get the version of the Firebird server used by a connection. + + Returns a tuple of (`major`, `minor`, `build`), three integers + representing the version of the attached server. + """ + + # This is the simpler approach (the other uses the services api), + # that for backward compatibility reasons returns a string like + # LI-V6.3.3.12981 Firebird 2.0 + # where the first version is a fake one resembling the old + # Interbase signature. + + fbconn = connection.connection + version = fbconn.server_version + + return self._parse_version_info(version) + + def _parse_version_info(self, version): + m = match( + r"\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?", version + ) + if not m: + raise AssertionError( + "Could not determine version from string '%s'" % version + ) + + if m.group(5) != None: + return tuple([int(x) for x in m.group(6, 7, 4)] + ["firebird"]) + else: + return tuple([int(x) for x in m.group(1, 2, 3)] + ["interbase"]) + + def is_disconnect(self, e, connection, cursor): + if isinstance( + e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError) + ): + msg = str(e) + return ( + "Error writing data to the connection" in msg + or "Unable to complete network request to host" in msg + or "Invalid connection state" in msg + or "Invalid cursor state" in msg + or "connection shutdown" in msg + ) + else: + return False + + +dialect = FBDialect_kinterbasdb |