first add files

Committed: 2023-10-08 20:59:00 +08:00
parent b494be364b
commit 1dac226337
991 changed files with 368151 additions and 40 deletions


@@ -0,0 +1,72 @@
# dialects/__init__.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
__all__ = (
"firebird",
"mssql",
"mysql",
"oracle",
"postgresql",
"sqlite",
"sybase",
)
from .. import util
def _auto_fn(name):
"""default dialect importer.
plugs into the :class:`.PluginLoader`
as a first-hit system.
"""
if "." in name:
dialect, driver = name.split(".")
else:
dialect = name
driver = "base"
try:
if dialect == "firebird":
try:
module = __import__("sqlalchemy_firebird")
except ImportError:
module = __import__("sqlalchemy.dialects.firebird").dialects
module = getattr(module, dialect)
elif dialect == "sybase":
try:
module = __import__("sqlalchemy_sybase")
except ImportError:
module = __import__("sqlalchemy.dialects.sybase").dialects
module = getattr(module, dialect)
elif dialect == "mariadb":
# it's "OK" for us to hardcode here since _auto_fn is already
# hardcoded. if mysql / mariadb etc were third party dialects
# they would just publish all the entrypoints, which would actually
# look much nicer.
module = __import__(
"sqlalchemy.dialects.mysql.mariadb"
).dialects.mysql.mariadb
return module.loader(driver)
else:
module = __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects
module = getattr(module, dialect)
except ImportError:
return None
if hasattr(module, driver):
module = getattr(module, driver)
return lambda: module.dialect
else:
return None
registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)
plugins = util.PluginLoader("sqlalchemy.plugins")
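# A minimal usage sketch (an illustration, assuming the standard
# create_engine() lookup path): the URL's drivername is passed to
# registry.load(), which falls back to _auto_fn above when no
# setuptools entrypoint matches, e.g.
#
#     registry.load("firebird.fdb")  # -> the fdb dialect class
#     registry.load("sqlite")        # driver defaults to "base"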


@@ -0,0 +1,41 @@
# firebird/__init__.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from sqlalchemy.dialects.firebird.base import BIGINT
from sqlalchemy.dialects.firebird.base import BLOB
from sqlalchemy.dialects.firebird.base import CHAR
from sqlalchemy.dialects.firebird.base import DATE
from sqlalchemy.dialects.firebird.base import FLOAT
from sqlalchemy.dialects.firebird.base import NUMERIC
from sqlalchemy.dialects.firebird.base import SMALLINT
from sqlalchemy.dialects.firebird.base import TEXT
from sqlalchemy.dialects.firebird.base import TIME
from sqlalchemy.dialects.firebird.base import TIMESTAMP
from sqlalchemy.dialects.firebird.base import VARCHAR
from . import base # noqa
from . import fdb # noqa
from . import kinterbasdb # noqa
base.dialect = dialect = fdb.dialect
__all__ = (
"SMALLINT",
"BIGINT",
"FLOAT",
"FLOAT",
"DATE",
"TIME",
"TEXT",
"NUMERIC",
"FLOAT",
"TIMESTAMP",
"VARCHAR",
"CHAR",
"BLOB",
"dialect",
)


@@ -0,0 +1,989 @@
# firebird/base.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: firebird
:name: Firebird
.. note::
The Firebird dialect within SQLAlchemy **is not currently supported**.
It is not tested within continuous integration and is likely to have
many issues and caveats not currently handled. Consider using the
`external dialect <https://github.com/pauldex/sqlalchemy-firebird>`_
instead.
.. deprecated:: 1.4 The internal Firebird dialect is deprecated and will be
removed in a future version. Use the external dialect.
Firebird Dialects
-----------------
Firebird offers two distinct dialects_ (not to be confused with a
SQLAlchemy ``Dialect``):
dialect 1
This is the old syntax and behaviour, inherited from Interbase pre-6.0.
dialect 3
This is the newer and supported syntax, introduced in Interbase 6.0.
The SQLAlchemy Firebird dialect detects these versions and
adjusts its representation of SQL accordingly. However,
support for dialect 1 is not well tested and probably has
incompatibilities.
Locking Behavior
----------------
Firebird locks tables aggressively. For this reason, a DROP TABLE may
hang until other transactions are released. SQLAlchemy does its best
to release transactions as quickly as possible. The most common cause
of hanging transactions is a non-fully consumed result set, i.e.::
result = engine.execute(text("select * from table"))
row = result.fetchone()
return
Where above, the ``CursorResult`` has not been fully consumed. The
connection will be returned to the pool and the transactional state
rolled back once the Python garbage collector reclaims the objects
which hold onto the connection, which often occurs asynchronously.
The above use case can be alleviated by calling ``first()`` on the
``CursorResult`` which will fetch the first row and immediately close
all remaining cursor/connection resources.
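For example, the same pattern written so that resources are released
immediately::
    result = engine.execute(text("select * from table"))
    row = result.first()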
RETURNING support
-----------------
Firebird 2.0 supports returning a result set from inserts, and 2.1
extends that to deletes and updates. This is generically exposed by
the SQLAlchemy ``returning()`` method, such as::
# INSERT..RETURNING
result = table.insert().returning(table.c.col1, table.c.col2).\
values(name='foo')
print(result.fetchall())
# UPDATE..RETURNING
raises = empl.update().returning(empl.c.id, empl.c.salary).\
where(empl.c.sales>100).\
values(dict(salary=empl.c.salary * 1.1))
print(raises.fetchall())
.. _dialects: https://mc-computing.com/Databases/Firebird/SQL_Dialect.html
"""
import datetime
from sqlalchemy import exc
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy import util
from sqlalchemy.engine import default
from sqlalchemy.engine import reflection
from sqlalchemy.sql import compiler
from sqlalchemy.sql import expression
from sqlalchemy.types import BIGINT
from sqlalchemy.types import BLOB
from sqlalchemy.types import DATE
from sqlalchemy.types import FLOAT
from sqlalchemy.types import INTEGER
from sqlalchemy.types import Integer
from sqlalchemy.types import NUMERIC
from sqlalchemy.types import SMALLINT
from sqlalchemy.types import TEXT
from sqlalchemy.types import TIME
from sqlalchemy.types import TIMESTAMP
RESERVED_WORDS = set(
[
"active",
"add",
"admin",
"after",
"all",
"alter",
"and",
"any",
"as",
"asc",
"ascending",
"at",
"auto",
"avg",
"before",
"begin",
"between",
"bigint",
"bit_length",
"blob",
"both",
"by",
"case",
"cast",
"char",
"character",
"character_length",
"char_length",
"check",
"close",
"collate",
"column",
"commit",
"committed",
"computed",
"conditional",
"connect",
"constraint",
"containing",
"count",
"create",
"cross",
"cstring",
"current",
"current_connection",
"current_date",
"current_role",
"current_time",
"current_timestamp",
"current_transaction",
"current_user",
"cursor",
"database",
"date",
"day",
"dec",
"decimal",
"declare",
"default",
"delete",
"desc",
"descending",
"disconnect",
"distinct",
"do",
"domain",
"double",
"drop",
"else",
"end",
"entry_point",
"escape",
"exception",
"execute",
"exists",
"exit",
"external",
"extract",
"fetch",
"file",
"filter",
"float",
"for",
"foreign",
"from",
"full",
"function",
"gdscode",
"generator",
"gen_id",
"global",
"grant",
"group",
"having",
"hour",
"if",
"in",
"inactive",
"index",
"inner",
"input_type",
"insensitive",
"insert",
"int",
"integer",
"into",
"is",
"isolation",
"join",
"key",
"leading",
"left",
"length",
"level",
"like",
"long",
"lower",
"manual",
"max",
"maximum_segment",
"merge",
"min",
"minute",
"module_name",
"month",
"names",
"national",
"natural",
"nchar",
"no",
"not",
"null",
"numeric",
"octet_length",
"of",
"on",
"only",
"open",
"option",
"or",
"order",
"outer",
"output_type",
"overflow",
"page",
"pages",
"page_size",
"parameter",
"password",
"plan",
"position",
"post_event",
"precision",
"primary",
"privileges",
"procedure",
"protected",
"rdb$db_key",
"read",
"real",
"record_version",
"recreate",
"recursive",
"references",
"release",
"reserv",
"reserving",
"retain",
"returning_values",
"returns",
"revoke",
"right",
"rollback",
"rows",
"row_count",
"savepoint",
"schema",
"second",
"segment",
"select",
"sensitive",
"set",
"shadow",
"shared",
"singular",
"size",
"smallint",
"snapshot",
"some",
"sort",
"sqlcode",
"stability",
"start",
"starting",
"starts",
"statistics",
"sub_type",
"sum",
"suspend",
"table",
"then",
"time",
"timestamp",
"to",
"trailing",
"transaction",
"trigger",
"trim",
"uncommitted",
"union",
"unique",
"update",
"upper",
"user",
"using",
"value",
"values",
"varchar",
"variable",
"varying",
"view",
"wait",
"when",
"where",
"while",
"with",
"work",
"write",
"year",
]
)
class _StringType(sqltypes.String):
"""Base for Firebird string types."""
def __init__(self, charset=None, **kw):
self.charset = charset
super(_StringType, self).__init__(**kw)
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""Firebird VARCHAR type"""
__visit_name__ = "VARCHAR"
def __init__(self, length=None, **kwargs):
super(VARCHAR, self).__init__(length=length, **kwargs)
class CHAR(_StringType, sqltypes.CHAR):
"""Firebird CHAR type"""
__visit_name__ = "CHAR"
def __init__(self, length=None, **kwargs):
super(CHAR, self).__init__(length=length, **kwargs)
class _FBDateTime(sqltypes.DateTime):
def bind_processor(self, dialect):
def process(value):
if type(value) == datetime.date:
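                # exact type check by design: datetime.datetime is a
                # subclass of datetime.date and must pass through
                # unchanged rather than be re-wrapped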
return datetime.datetime(value.year, value.month, value.day)
else:
return value
return process
colspecs = {sqltypes.DateTime: _FBDateTime}
ischema_names = {
"SHORT": SMALLINT,
"LONG": INTEGER,
"QUAD": FLOAT,
"FLOAT": FLOAT,
"DATE": DATE,
"TIME": TIME,
"TEXT": TEXT,
"INT64": BIGINT,
"DOUBLE": FLOAT,
"TIMESTAMP": TIMESTAMP,
"VARYING": VARCHAR,
"CSTRING": CHAR,
"BLOB": BLOB,
}
# TODO: date conversion types (should be implemented as _FBDateTime,
# _FBDate, etc. as bind/result functionality is required)
class FBTypeCompiler(compiler.GenericTypeCompiler):
def visit_boolean(self, type_, **kw):
return self.visit_SMALLINT(type_, **kw)
def visit_datetime(self, type_, **kw):
return self.visit_TIMESTAMP(type_, **kw)
def visit_TEXT(self, type_, **kw):
return "BLOB SUB_TYPE 1"
def visit_BLOB(self, type_, **kw):
return "BLOB SUB_TYPE 0"
def _extend_string(self, type_, basic):
charset = getattr(type_, "charset", None)
if charset is None:
return basic
else:
return "%s CHARACTER SET %s" % (basic, charset)
def visit_CHAR(self, type_, **kw):
basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw)
return self._extend_string(type_, basic)
def visit_VARCHAR(self, type_, **kw):
if not type_.length:
raise exc.CompileError(
"VARCHAR requires a length on dialect %s" % self.dialect.name
)
basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw)
return self._extend_string(type_, basic)
class FBCompiler(sql.compiler.SQLCompiler):
"""Firebird specific idiosyncrasies"""
ansi_bind_rules = True
    # def visit_contains_op_binary(self, binary, operator, **kw):
    #     can't use CONTAINING because it's case insensitive.
    # def visit_not_contains_op_binary(self, binary, operator, **kw):
    #     can't use NOT CONTAINING because it's case insensitive.
def visit_now_func(self, fn, **kw):
return "CURRENT_TIMESTAMP"
def visit_startswith_op_binary(self, binary, operator, **kw):
return "%s STARTING WITH %s" % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw),
)
def visit_not_startswith_op_binary(self, binary, operator, **kw):
return "%s NOT STARTING WITH %s" % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw),
)
def visit_mod_binary(self, binary, operator, **kw):
return "mod(%s, %s)" % (
self.process(binary.left, **kw),
self.process(binary.right, **kw),
)
def visit_alias(self, alias, asfrom=False, **kwargs):
if self.dialect._version_two:
return super(FBCompiler, self).visit_alias(
alias, asfrom=asfrom, **kwargs
)
else:
# Override to not use the AS keyword which FB 1.5 does not like
if asfrom:
alias_name = (
isinstance(alias.name, expression._truncated_label)
and self._truncated_identifier("alias", alias.name)
or alias.name
)
return (
self.process(alias.element, asfrom=asfrom, **kwargs)
+ " "
+ self.preparer.format_alias(alias, alias_name)
)
else:
return self.process(alias.element, **kwargs)
def visit_substring_func(self, func, **kw):
s = self.process(func.clauses.clauses[0])
start = self.process(func.clauses.clauses[1])
if len(func.clauses.clauses) > 2:
length = self.process(func.clauses.clauses[2])
return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length)
else:
return "SUBSTRING(%s FROM %s)" % (s, start)
def visit_length_func(self, function, **kw):
if self.dialect._version_two:
return "char_length" + self.function_argspec(function)
else:
return "strlen" + self.function_argspec(function)
visit_char_length_func = visit_length_func
def function_argspec(self, func, **kw):
# TODO: this probably will need to be
# narrowed to a fixed list, some no-arg functions
# may require parens - see similar example in the oracle
# dialect
if func.clauses is not None and len(func.clauses):
return self.process(func.clause_expr, **kw)
else:
return ""
def default_from(self):
return " FROM rdb$database"
def visit_sequence(self, seq, **kw):
return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
def get_select_precolumns(self, select, **kw):
"""Called when building a ``SELECT`` statement, position is just
before column list Firebird puts the limit and offset right
after the ``SELECT``...
"""
result = ""
if select._limit_clause is not None:
result += "FIRST %s " % self.process(select._limit_clause, **kw)
if select._offset_clause is not None:
result += "SKIP %s " % self.process(select._offset_clause, **kw)
result += super(FBCompiler, self).get_select_precolumns(select, **kw)
return result
def limit_clause(self, select, **kw):
"""Already taken care of in the `get_select_precolumns` method."""
return ""
def returning_clause(self, stmt, returning_cols):
columns = [
self._label_returning_column(stmt, c)
for c in expression._select_iterables(returning_cols)
]
return "RETURNING " + ", ".join(columns)
class FBDDLCompiler(sql.compiler.DDLCompiler):
"""Firebird syntactic idiosyncrasies"""
def visit_create_sequence(self, create):
"""Generate a ``CREATE GENERATOR`` statement for the sequence."""
# no syntax for these
# https://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html
if create.element.start is not None:
raise NotImplementedError(
"Firebird SEQUENCE doesn't support START WITH"
)
if create.element.increment is not None:
raise NotImplementedError(
"Firebird SEQUENCE doesn't support INCREMENT BY"
)
if self.dialect._version_two:
return "CREATE SEQUENCE %s" % self.preparer.format_sequence(
create.element
)
else:
return "CREATE GENERATOR %s" % self.preparer.format_sequence(
create.element
)
def visit_drop_sequence(self, drop):
"""Generate a ``DROP GENERATOR`` statement for the sequence."""
if self.dialect._version_two:
return "DROP SEQUENCE %s" % self.preparer.format_sequence(
drop.element
)
else:
return "DROP GENERATOR %s" % self.preparer.format_sequence(
drop.element
)
def visit_computed_column(self, generated):
if generated.persisted is not None:
raise exc.CompileError(
"Firebird computed columns do not support a persistence "
"method setting; set the 'persisted' flag to None for "
"Firebird support."
)
return "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
generated.sqltext, include_table=False, literal_binds=True
)
class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
"""Install Firebird specific reserved words."""
reserved_words = RESERVED_WORDS
illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union(
["_"]
)
def __init__(self, dialect):
super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
class FBExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
"""Get the next value from the sequence using ``gen_id()``."""
return self._execute_scalar(
"SELECT gen_id(%s, 1) FROM rdb$database"
% self.identifier_preparer.format_sequence(seq),
type_,
)
class FBDialect(default.DefaultDialect):
"""Firebird dialect"""
name = "firebird"
supports_statement_cache = True
max_identifier_length = 31
supports_sequences = True
sequences_optional = False
supports_default_values = True
postfetch_lastrowid = False
supports_native_boolean = False
requires_name_normalize = True
supports_empty_insert = False
statement_compiler = FBCompiler
ddl_compiler = FBDDLCompiler
preparer = FBIdentifierPreparer
type_compiler = FBTypeCompiler
execution_ctx_cls = FBExecutionContext
colspecs = colspecs
ischema_names = ischema_names
construct_arguments = []
    # defaults to dialect ver. 3,
    # will be autodetected upon
    # first connect
_version_two = True
def __init__(self, *args, **kwargs):
util.warn_deprecated(
"The firebird dialect is deprecated and will be removed "
"in a future version. This dialect is superseded by the external "
"dialect https://github.com/pauldex/sqlalchemy-firebird.",
version="1.4",
)
super(FBDialect, self).__init__(*args, **kwargs)
def initialize(self, connection):
super(FBDialect, self).initialize(connection)
self._version_two = (
"firebird" in self.server_version_info
and self.server_version_info >= (2,)
) or (
"interbase" in self.server_version_info
and self.server_version_info >= (6,)
)
if not self._version_two:
# TODO: whatever other pre < 2.0 stuff goes here
self.ischema_names = ischema_names.copy()
self.ischema_names["TIMESTAMP"] = sqltypes.DATE
self.colspecs = {sqltypes.DateTime: sqltypes.DATE}
self.implicit_returning = self._version_two and self.__dict__.get(
"implicit_returning", True
)
def has_table(self, connection, table_name, schema=None):
"""Return ``True`` if the given table exists, ignoring
the `schema`."""
self._ensure_has_table_connection(connection)
tblqry = """
SELECT 1 AS has_table FROM rdb$database
WHERE EXISTS (SELECT rdb$relation_name
FROM rdb$relations
WHERE rdb$relation_name=?)
"""
c = connection.exec_driver_sql(
tblqry, [self.denormalize_name(table_name)]
)
return c.first() is not None
def has_sequence(self, connection, sequence_name, schema=None):
"""Return ``True`` if the given sequence (generator) exists."""
genqry = """
SELECT 1 AS has_sequence FROM rdb$database
WHERE EXISTS (SELECT rdb$generator_name
FROM rdb$generators
WHERE rdb$generator_name=?)
"""
c = connection.exec_driver_sql(
genqry, [self.denormalize_name(sequence_name)]
)
return c.first() is not None
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
# there are two queries commonly mentioned for this.
# this one, using view_blr, is at the Firebird FAQ among other places:
# https://www.firebirdfaq.org/faq174/
s = """
select rdb$relation_name
from rdb$relations
where rdb$view_blr is null
and (rdb$system_flag is null or rdb$system_flag = 0);
"""
# the other query is this one. It's not clear if there's really
# any difference between these two. This link:
# https://www.alberton.info/firebird_sql_meta_info.html#.Ur3vXfZGni8
# states them as interchangeable. Some discussion at [ticket:2898]
# SELECT DISTINCT rdb$relation_name
# FROM rdb$relation_fields
# WHERE rdb$system_flag=0 AND rdb$view_context IS NULL
return [
self.normalize_name(row[0])
for row in connection.exec_driver_sql(s)
]
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
# see https://www.firebirdfaq.org/faq174/
s = """
select rdb$relation_name
from rdb$relations
where rdb$view_blr is not null
and (rdb$system_flag is null or rdb$system_flag = 0);
"""
return [
self.normalize_name(row[0])
for row in connection.exec_driver_sql(s)
]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
qry = """
SELECT rdb$view_source AS view_source
FROM rdb$relations
WHERE rdb$relation_name=?
"""
rp = connection.exec_driver_sql(
qry, [self.denormalize_name(view_name)]
)
row = rp.first()
if row:
return row["view_source"]
else:
return None
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
# Query to extract the PK/FK constrained fields of the given table
keyqry = """
SELECT se.rdb$field_name AS fname
FROM rdb$relation_constraints rc
JOIN rdb$index_segments se ON rc.rdb$index_name=se.rdb$index_name
WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
"""
tablename = self.denormalize_name(table_name)
# get primary key fields
c = connection.exec_driver_sql(keyqry, ["PRIMARY KEY", tablename])
pkfields = [self.normalize_name(r["fname"]) for r in c.fetchall()]
return {"constrained_columns": pkfields, "name": None}
@reflection.cache
def get_column_sequence(
self, connection, table_name, column_name, schema=None, **kw
):
tablename = self.denormalize_name(table_name)
colname = self.denormalize_name(column_name)
        # Heuristic query to determine the generator associated with a PK field
genqry = """
SELECT trigdep.rdb$depended_on_name AS fgenerator
FROM rdb$dependencies tabdep
JOIN rdb$dependencies trigdep
ON tabdep.rdb$dependent_name=trigdep.rdb$dependent_name
AND trigdep.rdb$depended_on_type=14
AND trigdep.rdb$dependent_type=2
JOIN rdb$triggers trig ON
trig.rdb$trigger_name=tabdep.rdb$dependent_name
WHERE tabdep.rdb$depended_on_name=?
AND tabdep.rdb$depended_on_type=0
AND trig.rdb$trigger_type=1
AND tabdep.rdb$field_name=?
AND (SELECT count(*)
FROM rdb$dependencies trigdep2
WHERE trigdep2.rdb$dependent_name = trigdep.rdb$dependent_name) = 2
"""
genr = connection.exec_driver_sql(genqry, [tablename, colname]).first()
if genr is not None:
return dict(name=self.normalize_name(genr["fgenerator"]))
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
# Query to extract the details of all the fields of the given table
tblqry = """
SELECT r.rdb$field_name AS fname,
r.rdb$null_flag AS null_flag,
t.rdb$type_name AS ftype,
f.rdb$field_sub_type AS stype,
f.rdb$field_length/
COALESCE(cs.rdb$bytes_per_character,1) AS flen,
f.rdb$field_precision AS fprec,
f.rdb$field_scale AS fscale,
COALESCE(r.rdb$default_source,
f.rdb$default_source) AS fdefault
FROM rdb$relation_fields r
JOIN rdb$fields f ON r.rdb$field_source=f.rdb$field_name
JOIN rdb$types t
ON t.rdb$type=f.rdb$field_type AND
t.rdb$field_name='RDB$FIELD_TYPE'
LEFT JOIN rdb$character_sets cs ON
f.rdb$character_set_id=cs.rdb$character_set_id
WHERE f.rdb$system_flag=0 AND r.rdb$relation_name=?
ORDER BY r.rdb$field_position
"""
# get the PK, used to determine the eventual associated sequence
pk_constraint = self.get_pk_constraint(connection, table_name)
pkey_cols = pk_constraint["constrained_columns"]
tablename = self.denormalize_name(table_name)
# get all of the fields for this table
c = connection.exec_driver_sql(tblqry, [tablename])
cols = []
while True:
row = c.fetchone()
if row is None:
break
name = self.normalize_name(row["fname"])
orig_colname = row["fname"]
# get the data type
colspec = row["ftype"].rstrip()
coltype = self.ischema_names.get(colspec)
if coltype is None:
util.warn(
"Did not recognize type '%s' of column '%s'"
% (colspec, name)
)
coltype = sqltypes.NULLTYPE
elif issubclass(coltype, Integer) and row["fprec"] != 0:
coltype = NUMERIC(
precision=row["fprec"], scale=row["fscale"] * -1
)
elif colspec in ("VARYING", "CSTRING"):
coltype = coltype(row["flen"])
elif colspec == "TEXT":
coltype = TEXT(row["flen"])
elif colspec == "BLOB":
if row["stype"] == 1:
coltype = TEXT()
else:
coltype = BLOB()
else:
coltype = coltype()
# does it have a default value?
defvalue = None
if row["fdefault"] is not None:
# the value comes down as "DEFAULT 'value'": there may be
# more than one whitespace around the "DEFAULT" keyword
# and it may also be lower case
# (see also https://tracker.firebirdsql.org/browse/CORE-356)
defexpr = row["fdefault"].lstrip()
assert defexpr[:8].rstrip().upper() == "DEFAULT", (
"Unrecognized default value: %s" % defexpr
)
defvalue = defexpr[8:].strip()
if defvalue == "NULL":
# Redundant
defvalue = None
col_d = {
"name": name,
"type": coltype,
"nullable": not bool(row["null_flag"]),
"default": defvalue,
"autoincrement": "auto",
}
if orig_colname.lower() == orig_colname:
col_d["quote"] = True
            # if the PK is a single field, try to see if it's linked to
            # a sequence through a trigger
if len(pkey_cols) == 1 and name == pkey_cols[0]:
seq_d = self.get_column_sequence(connection, tablename, name)
if seq_d is not None:
col_d["sequence"] = seq_d
cols.append(col_d)
return cols
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# Query to extract the details of each UK/FK of the given table
fkqry = """
SELECT rc.rdb$constraint_name AS cname,
cse.rdb$field_name AS fname,
ix2.rdb$relation_name AS targetrname,
se.rdb$field_name AS targetfname
FROM rdb$relation_constraints rc
JOIN rdb$indices ix1 ON ix1.rdb$index_name=rc.rdb$index_name
JOIN rdb$indices ix2 ON ix2.rdb$index_name=ix1.rdb$foreign_key
JOIN rdb$index_segments cse ON
cse.rdb$index_name=ix1.rdb$index_name
JOIN rdb$index_segments se
ON se.rdb$index_name=ix2.rdb$index_name
AND se.rdb$field_position=cse.rdb$field_position
WHERE rc.rdb$constraint_type=? AND rc.rdb$relation_name=?
ORDER BY se.rdb$index_name, se.rdb$field_position
"""
tablename = self.denormalize_name(table_name)
c = connection.exec_driver_sql(fkqry, ["FOREIGN KEY", tablename])
fks = util.defaultdict(
lambda: {
"name": None,
"constrained_columns": [],
"referred_schema": None,
"referred_table": None,
"referred_columns": [],
}
)
for row in c:
cname = self.normalize_name(row["cname"])
fk = fks[cname]
if not fk["name"]:
fk["name"] = cname
fk["referred_table"] = self.normalize_name(row["targetrname"])
fk["constrained_columns"].append(self.normalize_name(row["fname"]))
fk["referred_columns"].append(
self.normalize_name(row["targetfname"])
)
return list(fks.values())
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
qry = """
SELECT ix.rdb$index_name AS index_name,
ix.rdb$unique_flag AS unique_flag,
ic.rdb$field_name AS field_name
FROM rdb$indices ix
JOIN rdb$index_segments ic
ON ix.rdb$index_name=ic.rdb$index_name
LEFT OUTER JOIN rdb$relation_constraints
ON rdb$relation_constraints.rdb$index_name =
ic.rdb$index_name
WHERE ix.rdb$relation_name=? AND ix.rdb$foreign_key IS NULL
AND rdb$relation_constraints.rdb$constraint_type IS NULL
ORDER BY index_name, ic.rdb$field_position
"""
c = connection.exec_driver_sql(
qry, [self.denormalize_name(table_name)]
)
indexes = util.defaultdict(dict)
for row in c:
indexrec = indexes[row["index_name"]]
if "name" not in indexrec:
indexrec["name"] = self.normalize_name(row["index_name"])
indexrec["column_names"] = []
indexrec["unique"] = bool(row["unique_flag"])
indexrec["column_names"].append(
self.normalize_name(row["field_name"])
)
return list(indexes.values())


@@ -0,0 +1,112 @@
# firebird/fdb.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: firebird+fdb
:name: fdb
:dbapi: fdb
:connectstring: firebird+fdb://user:password@host:port/path/to/db[?key=value&key=value...]
:url: https://pypi.org/project/fdb/
fdb is a kinterbasdb compatible DBAPI for Firebird.
.. versionchanged:: 0.9 - The fdb dialect is now the default dialect
under the ``firebird://`` URL space, as ``fdb`` is now the official
Python driver for Firebird.
Arguments
----------
The ``fdb`` dialect is based on the
:mod:`sqlalchemy.dialects.firebird.kinterbasdb` dialect, however does not
accept every argument that Kinterbasdb does.
* ``enable_rowcount`` - True by default, setting this to False disables
the usage of "cursor.rowcount" with the
Kinterbasdb dialect, which SQLAlchemy ordinarily calls upon automatically
after any UPDATE or DELETE statement. When disabled, SQLAlchemy's
CursorResult will return -1 for result.rowcount. The rationale here is
that Kinterbasdb requires a second round trip to the database when
.rowcount is called - since SQLAlchemy's ``CursorResult`` automatically closes
the cursor after a non-result-returning statement, rowcount must be
called, if at all, before the result object is returned. Additionally,
cursor.rowcount may not return correct results with older versions
of Firebird, and setting this flag to False will also cause the
SQLAlchemy ORM to ignore its usage. The behavior can also be controlled on a
per-execution basis using the ``enable_rowcount`` option with
:meth:`_engine.Connection.execution_options`::
conn = engine.connect().execution_options(enable_rowcount=True)
r = conn.execute(stmt)
print(r.rowcount)
* ``retaining`` - False by default. Setting this to True will pass the
``retaining=True`` keyword argument to the ``.commit()`` and ``.rollback()``
methods of the DBAPI connection, which can improve performance in some
situations, but apparently with significant caveats.
Please read the fdb and/or kinterbasdb DBAPI documentation in order to
understand the implications of this flag.
.. versionchanged:: 0.9.0 - the ``retaining`` flag defaults to ``False``.
In 0.8 it defaulted to ``True``.
.. seealso::
https://pythonhosted.org/fdb/usage-guide.html#retaining-transactions
- information on the "retaining" flag.
""" # noqa
from .kinterbasdb import FBDialect_kinterbasdb
from ... import util
class FBDialect_fdb(FBDialect_kinterbasdb):
supports_statement_cache = True
def __init__(self, enable_rowcount=True, retaining=False, **kwargs):
super(FBDialect_fdb, self).__init__(
enable_rowcount=enable_rowcount, retaining=retaining, **kwargs
)
@classmethod
def dbapi(cls):
return __import__("fdb")
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
if opts.get("port"):
opts["host"] = "%s/%s" % (opts["host"], opts["port"])
del opts["port"]
opts.update(url.query)
util.coerce_kw_type(opts, "type_conv", int)
return ([], opts)
def _get_server_version_info(self, connection):
"""Get the version of the Firebird server used by a connection.
Returns a tuple of ``(major, minor, build, tag)``: three integers
        representing the version of the attached server, plus a trailing
        "firebird" or "interbase" marker appended by ``_parse_version_info``.
"""
# This is the simpler approach (the other uses the services api),
# that for backward compatibility reasons returns a string like
# LI-V6.3.3.12981 Firebird 2.0
# where the first version is a fake one resembling the old
# Interbase signature.
isc_info_firebird_version = 103
fbconn = connection.connection
version = fbconn.db_info(isc_info_firebird_version)
return self._parse_version_info(version)
dialect = FBDialect_fdb


@@ -0,0 +1,202 @@
# firebird/kinterbasdb.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: firebird+kinterbasdb
:name: kinterbasdb
:dbapi: kinterbasdb
:connectstring: firebird+kinterbasdb://user:password@host:port/path/to/db[?key=value&key=value...]
:url: https://firebirdsql.org/index.php?op=devel&sub=python
Arguments
----------
The Kinterbasdb backend accepts the ``enable_rowcount`` and ``retaining``
arguments accepted by the :mod:`sqlalchemy.dialects.firebird.fdb` dialect.
In addition, it also accepts the following:
* ``type_conv`` - select the kind of mapping done on the types: by default
SQLAlchemy uses 200 with Unicode, datetime and decimal support. See
the linked documents below for further information.
* ``concurrency_level`` - set the backend policy with regards to threading
issues: by default SQLAlchemy uses policy 1. See the linked documents
below for further information.
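For example (hypothetical credentials, shown only to illustrate the
query parameters)::
    create_engine(
        "firebird+kinterbasdb://user:pass@host/path/to/db"
        "?type_conv=200&concurrency_level=1"
    )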
.. seealso::
https://sourceforge.net/projects/kinterbasdb
https://kinterbasdb.sourceforge.net/dist_docs/usage.html#adv_param_conv_dynamic_type_translation
https://kinterbasdb.sourceforge.net/dist_docs/usage.html#special_issue_concurrency
""" # noqa
import decimal
from re import match
from .base import FBDialect
from .base import FBExecutionContext
from ... import types as sqltypes
from ... import util
class _kinterbasdb_numeric(object):
def bind_processor(self, dialect):
def process(value):
if isinstance(value, decimal.Decimal):
return str(value)
else:
return value
return process
class _FBNumeric_kinterbasdb(_kinterbasdb_numeric, sqltypes.Numeric):
pass
class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float):
pass
class FBExecutionContext_kinterbasdb(FBExecutionContext):
@property
def rowcount(self):
if self.execution_options.get(
"enable_rowcount", self.dialect.enable_rowcount
):
return self.cursor.rowcount
else:
return -1
class FBDialect_kinterbasdb(FBDialect):
driver = "kinterbasdb"
supports_statement_cache = True
supports_sane_rowcount = False
supports_sane_multi_rowcount = False
execution_ctx_cls = FBExecutionContext_kinterbasdb
supports_native_decimal = True
colspecs = util.update_copy(
FBDialect.colspecs,
{
sqltypes.Numeric: _FBNumeric_kinterbasdb,
sqltypes.Float: _FBFloat_kinterbasdb,
},
)
def __init__(
self,
type_conv=200,
concurrency_level=1,
enable_rowcount=True,
retaining=False,
**kwargs
):
super(FBDialect_kinterbasdb, self).__init__(**kwargs)
self.enable_rowcount = enable_rowcount
self.type_conv = type_conv
self.concurrency_level = concurrency_level
self.retaining = retaining
if enable_rowcount:
self.supports_sane_rowcount = True
@classmethod
def dbapi(cls):
return __import__("kinterbasdb")
def do_execute(self, cursor, statement, parameters, context=None):
        # kinterbasdb does not accept None, but wants an empty list
        # when there are no arguments.
cursor.execute(statement, parameters or [])
def do_rollback(self, dbapi_connection):
dbapi_connection.rollback(self.retaining)
def do_commit(self, dbapi_connection):
dbapi_connection.commit(self.retaining)
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
if opts.get("port"):
opts["host"] = "%s/%s" % (opts["host"], opts["port"])
del opts["port"]
opts.update(url.query)
util.coerce_kw_type(opts, "type_conv", int)
type_conv = opts.pop("type_conv", self.type_conv)
concurrency_level = opts.pop(
"concurrency_level", self.concurrency_level
)
if self.dbapi is not None:
initialized = getattr(self.dbapi, "initialized", None)
if initialized is None:
# CVS rev 1.96 changed the name of the attribute:
# https://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/
# Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96
initialized = getattr(self.dbapi, "_initialized", False)
if not initialized:
self.dbapi.init(
type_conv=type_conv, concurrency_level=concurrency_level
)
return ([], opts)
def _get_server_version_info(self, connection):
"""Get the version of the Firebird server used by a connection.
Returns a tuple of ``(major, minor, build, tag)``: three integers
        representing the version of the attached server, plus a trailing
        "firebird" or "interbase" marker appended by ``_parse_version_info``.
"""
# This is the simpler approach (the other uses the services api),
# that for backward compatibility reasons returns a string like
# LI-V6.3.3.12981 Firebird 2.0
# where the first version is a fake one resembling the old
# Interbase signature.
fbconn = connection.connection
version = fbconn.server_version
return self._parse_version_info(version)
def _parse_version_info(self, version):
m = match(
r"\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?", version
)
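        # e.g. "LI-V6.3.3.12981 Firebird 2.0" -> (2, 0, 12981, "firebird");
        # "WI-V6.3.3.12981" (no product suffix) -> (6, 3, 3, "interbase")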
if not m:
raise AssertionError(
"Could not determine version from string '%s'" % version
)
        if m.group(5) is not None:
return tuple([int(x) for x in m.group(6, 7, 4)] + ["firebird"])
else:
return tuple([int(x) for x in m.group(1, 2, 3)] + ["interbase"])
def is_disconnect(self, e, connection, cursor):
if isinstance(
e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)
):
msg = str(e)
return (
"Error writing data to the connection" in msg
or "Unable to complete network request to host" in msg
or "Invalid connection state" in msg
or "Invalid cursor state" in msg
or "connection shutdown" in msg
)
else:
return False
dialect = FBDialect_kinterbasdb


@@ -0,0 +1,85 @@
# mssql/__init__.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from . import base # noqa
from . import mxodbc # noqa
from . import pymssql # noqa
from . import pyodbc # noqa
from .base import BIGINT
from .base import BINARY
from .base import BIT
from .base import CHAR
from .base import DATE
from .base import DATETIME
from .base import DATETIME2
from .base import DATETIMEOFFSET
from .base import DECIMAL
from .base import FLOAT
from .base import IMAGE
from .base import INTEGER
from .base import JSON
from .base import MONEY
from .base import NCHAR
from .base import NTEXT
from .base import NUMERIC
from .base import NVARCHAR
from .base import REAL
from .base import ROWVERSION
from .base import SMALLDATETIME
from .base import SMALLINT
from .base import SMALLMONEY
from .base import SQL_VARIANT
from .base import TEXT
from .base import TIME
from .base import TIMESTAMP
from .base import TINYINT
from .base import try_cast
from .base import UNIQUEIDENTIFIER
from .base import VARBINARY
from .base import VARCHAR
from .base import XML
base.dialect = dialect = pyodbc.dialect
__all__ = (
"JSON",
"INTEGER",
"BIGINT",
"SMALLINT",
"TINYINT",
"VARCHAR",
"NVARCHAR",
"CHAR",
"NCHAR",
"TEXT",
"NTEXT",
"DECIMAL",
"NUMERIC",
"FLOAT",
"DATETIME",
"DATETIME2",
"DATETIMEOFFSET",
"DATE",
"TIME",
"SMALLDATETIME",
"BINARY",
"VARBINARY",
"BIT",
"REAL",
"IMAGE",
"TIMESTAMP",
"ROWVERSION",
"MONEY",
"SMALLMONEY",
"UNIQUEIDENTIFIER",
"SQL_VARIANT",
"XML",
"dialect",
"try_cast",
)

File diff suppressed because it is too large.


@@ -0,0 +1,232 @@
# mssql/information_schema.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from ... import cast
from ... import Column
from ... import MetaData
from ... import Table
from ... import util
from ...ext.compiler import compiles
from ...sql import expression
from ...types import Boolean
from ...types import Integer
from ...types import Numeric
from ...types import String
from ...types import TypeDecorator
from ...types import Unicode
ischema = MetaData()
class CoerceUnicode(TypeDecorator):
impl = Unicode
cache_ok = True
def process_bind_param(self, value, dialect):
if util.py2k and isinstance(value, util.binary_type):
value = value.decode(dialect.encoding)
return value
def bind_expression(self, bindvalue):
return _cast_on_2005(bindvalue)
class _cast_on_2005(expression.ColumnElement):
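    """Defer CAST-ing a bound value to ``Unicode`` until compile time,
    so that the cast can be skipped for servers older than SQL Server
    2005 (see ``_compile`` below)."""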
def __init__(self, bindvalue):
self.bindvalue = bindvalue
@compiles(_cast_on_2005)
def _compile(element, compiler, **kw):
from . import base
if (
compiler.dialect.server_version_info is None
or compiler.dialect.server_version_info < base.MS_2005_VERSION
):
return compiler.process(element.bindvalue, **kw)
else:
return compiler.process(cast(element.bindvalue, Unicode), **kw)
schemata = Table(
"SCHEMATA",
ischema,
Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
schema="INFORMATION_SCHEMA",
)
tables = Table(
"TABLES",
ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("TABLE_TYPE", CoerceUnicode, key="table_type"),
schema="INFORMATION_SCHEMA",
)
columns = Table(
"COLUMNS",
ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("IS_NULLABLE", Integer, key="is_nullable"),
Column("DATA_TYPE", String, key="data_type"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
Column(
"CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
),
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
Column("COLUMN_DEFAULT", Integer, key="column_default"),
Column("COLLATION_NAME", String, key="collation_name"),
schema="INFORMATION_SCHEMA",
)
mssql_temp_table_columns = Table(
"COLUMNS",
ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("IS_NULLABLE", Integer, key="is_nullable"),
Column("DATA_TYPE", String, key="data_type"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
Column(
"CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
),
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
Column("COLUMN_DEFAULT", Integer, key="column_default"),
Column("COLLATION_NAME", String, key="collation_name"),
schema="tempdb.INFORMATION_SCHEMA",
)
constraints = Table(
"TABLE_CONSTRAINTS",
ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
Column("CONSTRAINT_TYPE", CoerceUnicode, key="constraint_type"),
schema="INFORMATION_SCHEMA",
)
column_constraints = Table(
"CONSTRAINT_COLUMN_USAGE",
ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
schema="INFORMATION_SCHEMA",
)
key_constraints = Table(
"KEY_COLUMN_USAGE",
ischema,
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
schema="INFORMATION_SCHEMA",
)
ref_constraints = Table(
"REFERENTIAL_CONSTRAINTS",
ischema,
Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
# TODO: is CATLOG misspelled ?
Column(
"UNIQUE_CONSTRAINT_CATLOG",
CoerceUnicode,
key="unique_constraint_catalog",
),
Column(
"UNIQUE_CONSTRAINT_SCHEMA",
CoerceUnicode,
key="unique_constraint_schema",
),
Column(
"UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"
),
Column("MATCH_OPTION", String, key="match_option"),
Column("UPDATE_RULE", String, key="update_rule"),
Column("DELETE_RULE", String, key="delete_rule"),
schema="INFORMATION_SCHEMA",
)
views = Table(
"VIEWS",
ischema,
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
Column("CHECK_OPTION", String, key="check_option"),
Column("IS_UPDATABLE", String, key="is_updatable"),
schema="INFORMATION_SCHEMA",
)
computed_columns = Table(
"computed_columns",
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("is_computed", Boolean),
Column("is_persisted", Boolean),
Column("definition", CoerceUnicode),
schema="sys",
)
sequences = Table(
"SEQUENCES",
ischema,
Column("SEQUENCE_CATALOG", CoerceUnicode, key="sequence_catalog"),
Column("SEQUENCE_SCHEMA", CoerceUnicode, key="sequence_schema"),
Column("SEQUENCE_NAME", CoerceUnicode, key="sequence_name"),
schema="INFORMATION_SCHEMA",
)
class IdentitySqlVariant(TypeDecorator):
r"""This type casts sql_variant columns in the identity_columns view
to numeric. This is required because:
* pyodbc does not support sql_variant
    * pymssql under Python 2 returns the byte representation of the number
      (int 1 is returned as "\x01\x00\x00\x00"); under Python 3 it returns
      the correct value as a string.
"""
impl = Unicode
cache_ok = True
def column_expression(self, colexpr):
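        # render the SELECTed column as CAST(col AS NUMERIC)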
return cast(colexpr, Numeric)
identity_columns = Table(
"identity_columns",
ischema,
Column("object_id", Integer),
Column("name", CoerceUnicode),
Column("is_identity", Boolean),
Column("seed_value", IdentitySqlVariant),
Column("increment_value", IdentitySqlVariant),
Column("last_value", IdentitySqlVariant),
Column("is_not_for_replication", Boolean),
schema="sys",
)


@@ -0,0 +1,125 @@
from ... import types as sqltypes
# technically, all the dialect-specific datatypes that don't have any special
# behaviors would be private with names like _MSJson. However, we haven't been
# doing this for mysql.JSON or sqlite.JSON which both have JSON / JSONIndexType
# / JSONPathType in their json.py files, so keep consistent with that
# sub-convention for now. A future change can update them all to be
# package-private at once.
class JSON(sqltypes.JSON):
"""MSSQL JSON type.
MSSQL supports JSON-formatted data as of SQL Server 2016.
The :class:`_mssql.JSON` datatype at the DDL level will represent the
datatype as ``NVARCHAR(max)``, but provides for JSON-level comparison
functions as well as Python coercion behavior.
:class:`_mssql.JSON` is used automatically whenever the base
:class:`_types.JSON` datatype is used against a SQL Server backend.
.. seealso::
:class:`_types.JSON` - main documentation for the generic
cross-platform JSON datatype.
The :class:`_mssql.JSON` type supports persistence of JSON values
as well as the core index operations provided by :class:`_types.JSON`
datatype, by adapting the operations to render the ``JSON_VALUE``
or ``JSON_QUERY`` functions at the database level.
The SQL Server :class:`_mssql.JSON` type necessarily makes use of the
``JSON_QUERY`` and ``JSON_VALUE`` functions when querying for elements
of a JSON object. These two functions have a major restriction in that
they are **mutually exclusive** based on the type of object to be returned.
The ``JSON_QUERY`` function **only** returns a JSON dictionary or list,
but not an individual string, numeric, or boolean element; the
``JSON_VALUE`` function **only** returns an individual string, numeric,
or boolean element.  **Both functions either return NULL or raise
an error if they are not used against the correct expected value.**
To handle this awkward requirement, indexed access rules are as follows:
1. When extracting a sub element from a JSON that is itself a JSON
dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor
should be used::
stmt = select(
data_table.c.data["some key"].as_json()
).where(
data_table.c.data["some key"].as_json() == {"sub": "structure"}
)
2. When extracting a sub element from a JSON that is a plain boolean,
string, integer, or float, use the appropriate method among
:meth:`_types.JSON.Comparator.as_boolean`,
:meth:`_types.JSON.Comparator.as_string`,
:meth:`_types.JSON.Comparator.as_integer`,
:meth:`_types.JSON.Comparator.as_float`::
stmt = select(
data_table.c.data["some key"].as_string()
).where(
data_table.c.data["some key"].as_string() == "some string"
)
.. versionadded:: 1.4
"""
# note there was a result processor here that was looking for "number",
# but none of the tests seem to exercise it.
# Note: these objects currently match exactly those of MySQL, however since
# these are not generalizable to all JSON implementations, remain separately
# implemented for each dialect.
class _FormatTypeMixin(object):
def _format_value(self, value):
raise NotImplementedError()
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value):
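        # e.g. 5 -> "$[5]"; "some key" -> '$."some key"'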
if isinstance(value, int):
value = "$[%s]" % value
else:
value = '$."%s"' % value
return value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
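        # e.g. ["a", 5, "b"] -> '$."a"[5]."b"'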
return "$%s" % (
"".join(
[
"[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
for elem in value
]
)
)


@@ -0,0 +1,150 @@
# mssql/mxodbc.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+mxodbc
:name: mxODBC
:dbapi: mxodbc
:connectstring: mssql+mxodbc://<username>:<password>@<dsnname>
:url: https://www.egenix.com/
.. deprecated:: 1.4 The mxODBC DBAPI is deprecated and will be removed
in a future version. Please use one of the supported DBAPIs to
connect to mssql.
Execution Modes
---------------
mxODBC features two styles of statement execution, using the
``cursor.execute()`` and ``cursor.executedirect()`` methods (the second being
an extension to the DBAPI specification). The former makes use of a particular
API call specific to the SQL Server Native Client ODBC driver known as
SQLDescribeParam, while the latter does not.
mxODBC apparently only makes repeated use of a single prepared statement
when SQLDescribeParam is used. The advantage to prepared statement reuse is
one of performance. The disadvantage is that SQLDescribeParam has a limited
set of scenarios in which bind parameters are understood, including that they
cannot be placed within the argument lists of function calls, anywhere outside
the FROM, or even within subqueries within the FROM clause - making the usage
of bind parameters within SELECT statements impossible for all but the most
simplistic statements.
For this reason, the mxODBC dialect uses the "native" mode by default only for
INSERT, UPDATE, and DELETE statements, and uses the escaped string mode for
all other statements.
This behavior can be controlled via
:meth:`~sqlalchemy.sql.expression.Executable.execution_options` using the
``native_odbc_execute`` flag with a value of ``True`` or ``False``, where a
value of ``True`` will unconditionally use native bind parameters and a value
of ``False`` will unconditionally use string-escaped parameters.
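For example (a hypothetical statement, illustrating the flag only)::
    stmt = (
        table.update()
        .values(data="foo")
        .execution_options(native_odbc_execute=True)
    )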
"""
from .base import _MSDate
from .base import _MSDateTime
from .base import _MSTime
from .base import MSDialect
from .base import VARBINARY
from .pyodbc import _MSNumeric_pyodbc
from .pyodbc import MSExecutionContext_pyodbc
from ... import types as sqltypes
from ...connectors.mxodbc import MxODBCConnector
class _MSNumeric_mxodbc(_MSNumeric_pyodbc):
"""Include pyodbc's numeric processor."""
class _MSDate_mxodbc(_MSDate):
def bind_processor(self, dialect):
def process(value):
if value is not None:
return "%s-%s-%s" % (value.year, value.month, value.day)
else:
return None
return process
class _MSTime_mxodbc(_MSTime):
def bind_processor(self, dialect):
def process(value):
if value is not None:
return "%s:%s:%s" % (value.hour, value.minute, value.second)
else:
return None
return process
class _VARBINARY_mxodbc(VARBINARY):
"""
mxODBC Support for VARBINARY column types.
This handles the special case for null VARBINARY values,
which maps None values to the mx.ODBC.Manager.BinaryNull symbol.
"""
def bind_processor(self, dialect):
if dialect.dbapi is None:
return None
DBAPIBinary = dialect.dbapi.Binary
def process(value):
if value is not None:
return DBAPIBinary(value)
else:
# should pull from mx.ODBC.Manager.BinaryNull
return dialect.dbapi.BinaryNull
return process
class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
"""
The pyodbc execution context is useful for enabling
SELECT SCOPE_IDENTITY in cases where OUTPUT clause
does not work (tables with insert triggers).
"""
# todo - investigate whether the pyodbc execution context
# is really only being used in cases where OUTPUT
# won't work.
class MSDialect_mxodbc(MxODBCConnector, MSDialect):
# this is only needed if "native ODBC" mode is used,
# which is now disabled by default.
# statement_compiler = MSSQLStrictCompiler
supports_statement_cache = True
execution_ctx_cls = MSExecutionContext_mxodbc
# flag used by _MSNumeric_mxodbc
_need_decimal_fix = True
colspecs = {
sqltypes.Numeric: _MSNumeric_mxodbc,
sqltypes.DateTime: _MSDateTime,
sqltypes.Date: _MSDate_mxodbc,
sqltypes.Time: _MSTime_mxodbc,
VARBINARY: _VARBINARY_mxodbc,
sqltypes.LargeBinary: _VARBINARY_mxodbc,
}
def __init__(self, description_encoding=None, **params):
super(MSDialect_mxodbc, self).__init__(**params)
self.description_encoding = description_encoding
dialect = MSDialect_mxodbc


@@ -0,0 +1,116 @@
from sqlalchemy import inspect
from sqlalchemy import Integer
from ... import create_engine
from ... import exc
from ...schema import Column
from ...schema import DropConstraint
from ...schema import ForeignKeyConstraint
from ...schema import MetaData
from ...schema import Table
from ...testing.provision import create_db
from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import get_temp_table_name
from ...testing.provision import log
from ...testing.provision import run_reap_dbs
from ...testing.provision import temp_table_keyword_args
@create_db.for_db("mssql")
def _mssql_create_db(cfg, eng, ident):
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
conn.exec_driver_sql("create database %s" % ident)
conn.exec_driver_sql(
"ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident
)
conn.exec_driver_sql(
"ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident
)
conn.exec_driver_sql("use %s" % ident)
conn.exec_driver_sql("create schema test_schema")
conn.exec_driver_sql("create schema test_schema_2")
@drop_db.for_db("mssql")
def _mssql_drop_db(cfg, eng, ident):
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
_mssql_drop_ignore(conn, ident)
def _mssql_drop_ignore(conn, ident):
try:
# typically when this happens, we can't KILL the session anyway,
# so let the cleanup process drop the DBs
# for row in conn.exec_driver_sql(
# "select session_id from sys.dm_exec_sessions "
# "where database_id=db_id('%s')" % ident):
# log.info("killing SQL server session %s", row['session_id'])
# conn.exec_driver_sql("kill %s" % row['session_id'])
conn.exec_driver_sql("drop database %s" % ident)
log.info("Reaped db: %s", ident)
return True
except exc.DatabaseError as err:
log.warning("couldn't drop db: %s", err)
return False
@run_reap_dbs.for_db("mssql")
def _reap_mssql_dbs(url, idents):
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
log.info("identifiers in file: %s", ", ".join(idents))
to_reap = conn.exec_driver_sql(
"select d.name from sys.databases as d where name "
"like 'TEST_%' and not exists (select session_id "
"from sys.dm_exec_sessions "
"where database_id=d.database_id)"
)
all_names = {dbname.lower() for (dbname,) in to_reap}
to_drop = set()
for name in all_names:
if name in idents:
to_drop.add(name)
dropped = total = 0
for total, dbname in enumerate(to_drop, 1):
if _mssql_drop_ignore(conn, dbname):
dropped += 1
log.info(
"Dropped %d out of %d stale databases detected", dropped, total
)
@temp_table_keyword_args.for_db("mssql")
def _mssql_temp_table_keyword_args(cfg, eng):
return {}
@get_temp_table_name.for_db("mssql")
def _mssql_get_temp_table_name(cfg, eng, base_name):
return "##" + base_name
@drop_all_schema_objects_pre_tables.for_db("mssql")
def drop_all_schema_objects_pre_tables(cfg, eng):
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
inspector = inspect(conn)
for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2):
for tname in inspector.get_table_names(schema=schema):
tb = Table(
tname,
MetaData(),
Column("x", Integer),
Column("y", Integer),
schema=schema,
)
for fk in inspect(conn).get_foreign_keys(tname, schema=schema):
conn.execute(
DropConstraint(
ForeignKeyConstraint(
[tb.c.x], [tb.c.y], name=fk["name"]
)
)
)


@@ -0,0 +1,138 @@
# mssql/pymssql.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mssql+pymssql
:name: pymssql
:dbapi: pymssql
:connectstring: mssql+pymssql://<username>:<password>@<freetds_name>/?charset=utf8
pymssql is a Python module that provides a Python DBAPI interface around
`FreeTDS <https://www.freetds.org/>`_.
.. note::
pymssql is currently not included in SQLAlchemy's continuous integration
(CI) testing.
Modern versions of this driver worked very well with SQL Server and FreeTDS
from Linux and were highly recommended. However, pymssql is currently
unmaintained and has fallen behind the progress of the Microsoft ODBC driver in
its support for newer features of SQL Server. The latest official release of
pymssql at the time of this document is version 2.1.4 (August, 2018) and it
lacks support for:
1. table-valued parameters (TVPs),
2. ``datetimeoffset`` columns using timezone-aware ``datetime`` objects
(values are sent and retrieved as strings), and
3. encrypted connections (e.g., to Azure SQL), when pymssql is installed from
the pre-built wheels. Support for encrypted connections requires building
pymssql from source, which can be a nuisance, especially under Windows.
The above features are all supported by mssql+pyodbc when using Microsoft's
ODBC Driver for SQL Server (msodbcsql), which is now available for Windows,
(several flavors of) Linux, and macOS.
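A minimal engine for this dialect, assuming a FreeTDS host entry named
``freetds_name`` (credentials and database name below are illustrative)::
    engine = create_engine(
        "mssql+pymssql://scott:tiger@freetds_name/mydb?charset=utf8"
    )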
""" # noqa
import re
from .base import MSDialect
from .base import MSIdentifierPreparer
from ... import processors
from ... import types as sqltypes
from ... import util
class _MSNumeric_pymssql(sqltypes.Numeric):
def result_processor(self, dialect, type_):
if not self.asdecimal:
return processors.to_float
else:
return sqltypes.Numeric.result_processor(self, dialect, type_)
class MSIdentifierPreparer_pymssql(MSIdentifierPreparer):
def __init__(self, dialect):
super(MSIdentifierPreparer_pymssql, self).__init__(dialect)
# pymssql has the very unusual behavior that it uses pyformat
# yet does not require that percent signs be doubled
self._double_percents = False
class MSDialect_pymssql(MSDialect):
supports_statement_cache = True
supports_native_decimal = True
driver = "pymssql"
preparer = MSIdentifierPreparer_pymssql
colspecs = util.update_copy(
MSDialect.colspecs,
{sqltypes.Numeric: _MSNumeric_pymssql, sqltypes.Float: sqltypes.Float},
)
@classmethod
def dbapi(cls):
module = __import__("pymssql")
        # pymssql < 2.1.1 doesn't have a Binary method; fall back to str()
client_ver = tuple(int(x) for x in module.__version__.split("."))
if client_ver < (2, 1, 1):
# TODO: monkeypatching here is less than ideal
module.Binary = lambda x: x if hasattr(x, "decode") else str(x)
if client_ver < (1,):
util.warn(
"The pymssql dialect expects at least "
"the 1.0 series of the pymssql DBAPI."
)
return module
def _get_server_version_info(self, connection):
vers = connection.exec_driver_sql("select @@version").scalar()
m = re.match(r"Microsoft .*? - (\d+)\.(\d+)\.(\d+)\.(\d+)", vers)
if m:
return tuple(int(x) for x in m.group(1, 2, 3, 4))
else:
return None
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
opts.update(url.query)
port = opts.pop("port", None)
if port and "host" in opts:
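            # pymssql accepts the port appended to the host as "host:port"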
opts["host"] = "%s:%s" % (opts["host"], port)
return [[], opts]
def is_disconnect(self, e, connection, cursor):
for msg in (
"Adaptive Server connection timed out",
"Net-Lib error during Connection reset by peer",
"message 20003", # connection timeout
"Error 10054",
"Not connected to any MS SQL server",
"Connection is closed",
"message 20006", # Write to the server failed
"message 20017", # Unexpected EOF from the server
"message 20047", # DBPROCESS is dead or not enabled
):
if msg in str(e):
return True
else:
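            # for/else: reached only when no disconnect message matched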
return False
def set_isolation_level(self, connection, level):
if level == "AUTOCOMMIT":
connection.autocommit(True)
else:
connection.autocommit(False)
super(MSDialect_pymssql, self).set_isolation_level(
connection, level
)
dialect = MSDialect_pymssql

View File

@@ -0,0 +1,673 @@
# mssql/pyodbc.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: mssql+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: mssql+pyodbc://<username>:<password>@<dsnname>
:url: https://pypi.org/project/pyodbc/
Connecting to PyODBC
--------------------
The URL is translated to a PyODBC connection string, as
detailed in `ConnectionStrings <https://code.google.com/p/pyodbc/wiki/ConnectionStrings>`_.
DSN Connections
^^^^^^^^^^^^^^^
A DSN connection in ODBC means that a pre-existing ODBC datasource is
configured on the client machine. The application then specifies the name
of this datasource, which encompasses details such as the specific ODBC driver
in use as well as the network address of the database. Assuming a datasource
is configured on the client, a basic DSN-based connection looks like::
engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn")
The above will pass the following connection string to PyODBC::
DSN=some_dsn;UID=scott;PWD=tiger
If the username and password are omitted, the DSN form will also add
the ``Trusted_Connection=yes`` directive to the ODBC string.
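For example, a URL with no credentials (the DSN name is illustrative)::
    engine = create_engine("mssql+pyodbc://@some_dsn")
renders the ODBC string as ``DSN=some_dsn;Trusted_Connection=Yes``.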
Hostname Connections
^^^^^^^^^^^^^^^^^^^^
Hostname-based connections are also supported by pyodbc. These are often
easier to use than a DSN and have the additional advantage that the specific
database name to connect towards may be specified locally in the URL, rather
than it being fixed as part of a datasource configuration.
When using a hostname connection, the driver name must also be specified in the
query parameters of the URL. As these names usually have spaces in them, the
name must be URL encoded which means using plus signs for spaces::
engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server")
Other keywords interpreted by the Pyodbc dialect to be passed to
``pyodbc.connect()`` in both the DSN and hostname cases include:
``odbc_autotranslate``, ``ansi``, ``unicode_results``, ``autocommit``,
``authentication``.
Note that in order for the dialect to recognize these keywords
(including the ``driver`` keyword above) they must be all lowercase.
Multiple additional keyword arguments must be separated by an
ampersand (``&``), not a semicolon::
engine = create_engine(
"mssql+pyodbc://scott:tiger@myhost:49242/databasename"
"?driver=ODBC+Driver+17+for+SQL+Server"
"&authentication=ActiveDirectoryIntegrated"
)
The equivalent URL can be constructed using :class:`_sa.engine.URL`::
from sqlalchemy.engine import URL
connection_url = URL.create(
"mssql+pyodbc",
username="scott",
password="tiger",
host="myhost",
port=49242,
database="databasename",
query={
"driver": "ODBC Driver 17 for SQL Server",
"authentication": "ActiveDirectoryIntegrated",
},
)
Pass through exact Pyodbc string
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
A PyODBC connection string can also be sent in pyodbc's format directly, as
specified in `the PyODBC documentation
<https://github.com/mkleehammer/pyodbc/wiki/Connecting-to-databases>`_,
using the parameter ``odbc_connect``. A :class:`_sa.engine.URL` object
can help make this easier::
from sqlalchemy.engine import URL
connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password"
connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string})
engine = create_engine(connection_url)
.. _mssql_pyodbc_access_tokens:
Connecting to databases with access tokens
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Some database servers are set up to only accept access tokens for login. For
example, SQL Server allows the use of Azure Active Directory tokens to connect
to databases. This requires creating a credential object using the
``azure-identity`` library. More information about the authentication step can be
found in `Microsoft's documentation
<https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=bash>`_.
After getting an engine, the credentials need to be sent to ``pyodbc.connect``
each time a connection is requested. One way to do this is to set up an event
listener on the engine that adds the credential token to the dialect's connect
call. This is discussed more generally in :ref:`engines_dynamic_tokens`. For
SQL Server in particular, this is passed as an ODBC connection attribute with
a data structure `described by Microsoft
<https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_.
The following code snippet will create an engine that connects to an Azure SQL
database using Azure credentials::
import struct
from sqlalchemy import create_engine, event
from sqlalchemy.engine.url import URL
from azure import identity
SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h
TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database
connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server"
engine = create_engine(connection_string)
azure_credentials = identity.DefaultAzureCredential()
@event.listens_for(engine, "do_connect")
def provide_token(dialect, conn_rec, cargs, cparams):
# remove the "Trusted_Connection" parameter that SQLAlchemy adds
cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")
# create token credential
raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le")
token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)
# apply it to keyword arguments
cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct}
.. tip::
The ``Trusted_Connection`` token is currently added by the SQLAlchemy
pyodbc dialect when no username or password is present. This needs
to be removed per Microsoft's
`documentation for Azure access tokens
<https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_,
stating that a connection string when using an access token must not contain
``UID``, ``PWD``, ``Authentication`` or ``Trusted_Connection`` parameters.
.. _azure_synapse_ignore_no_transaction_on_rollback:
Avoiding transaction-related exceptions on Azure Synapse Analytics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Azure Synapse Analytics has a significant difference in its transaction
handling compared to plain SQL Server; in some cases an error within a Synapse
transaction can cause it to be arbitrarily terminated on the server side, which
then causes the DBAPI ``.rollback()`` method (as well as ``.commit()``) to
fail. The issue prevents the usual DBAPI contract of allowing ``.rollback()``
to pass silently if no transaction is present, as the driver does not expect
this condition. The symptom of this failure is an exception with a message
resembling 'No corresponding transaction found. (111214)' when attempting to
emit a ``.rollback()`` after an operation had a failure of some kind.
This specific case can be handled by passing ``ignore_no_transaction_on_rollback=True`` to
the SQL Server dialect via the :func:`_sa.create_engine` function as follows::
engine = create_engine(connection_url, ignore_no_transaction_on_rollback=True)
Using the above parameter, the dialect will catch ``ProgrammingError``
exceptions raised during ``connection.rollback()`` and emit a warning
if the error message contains code ``111214``, however will not raise
an exception.
.. versionadded:: 1.4.40 Added the
``ignore_no_transaction_on_rollback=True`` parameter.
Enable autocommit for Azure SQL Data Warehouse (DW) connections
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Azure SQL Data Warehouse does not support transactions,
and that can cause problems with SQLAlchemy's "autobegin" (and implicit
commit/rollback) behavior. We can avoid these problems by enabling autocommit
at both the pyodbc and engine levels::
connection_url = sa.engine.URL.create(
"mssql+pyodbc",
username="scott",
password="tiger",
host="dw.azure.example.com",
database="mydb",
query={
"driver": "ODBC Driver 17 for SQL Server",
"autocommit": "True",
},
)
engine = create_engine(connection_url).execution_options(
isolation_level="AUTOCOMMIT"
)
Avoiding sending large string parameters as TEXT/NTEXT
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
By default, for historical reasons, Microsoft's ODBC drivers for SQL Server
send long string parameters (greater than 4000 SBCS characters or 2000 Unicode
characters) as TEXT/NTEXT values. TEXT and NTEXT have been deprecated for many
years and are starting to cause compatibility issues with newer versions of
SQL Server/Azure. For example, see `this
issue <https://github.com/mkleehammer/pyodbc/issues/835>`_.
Starting with ODBC Driver 18 for SQL Server we can override the legacy
behavior and pass long strings as varchar(max)/nvarchar(max) using the
``LongAsMax=Yes`` connection string parameter::
connection_url = sa.engine.URL.create(
"mssql+pyodbc",
username="scott",
password="tiger",
host="mssqlserver.example.com",
database="mydb",
query={
"driver": "ODBC Driver 18 for SQL Server",
"LongAsMax": "Yes",
},
)
Pyodbc Pooling / connection close behavior
------------------------------------------
PyODBC uses internal `pooling
<https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ by
default, which means connections will be longer lived than they are within
SQLAlchemy itself. As SQLAlchemy has its own pooling behavior, it is often
preferable to disable this behavior. This behavior can only be disabled
globally at the PyODBC module level, **before** any connections are made::
import pyodbc
pyodbc.pooling = False
# don't use the engine before pooling is set to False
engine = create_engine("mssql+pyodbc://user:pass@dsn")
If this variable is left at its default value of ``True``, **the application
will continue to maintain active database connections**, even when the
SQLAlchemy engine itself fully discards a connection or if the engine is
disposed.
.. seealso::
`pooling <https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ -
in the PyODBC documentation.
Driver / Unicode Support
-------------------------
PyODBC works best with Microsoft ODBC drivers, particularly in the area
of Unicode support on both Python 2 and Python 3.
Using the FreeTDS ODBC drivers on Linux or OSX with PyODBC is **not**
recommended; there have been historically many Unicode-related issues
in this area, including before Microsoft offered ODBC drivers for Linux
and OSX. Now that Microsoft offers drivers for all platforms, these are
the recommended choice for PyODBC. FreeTDS remains relevant for
non-ODBC drivers such as pymssql where it works very well.
Rowcount Support
----------------
Pyodbc only has partial support for rowcount. See the notes at
:ref:`mssql_rowcount_versioning` for important notes when using ORM
versioning.
.. _mssql_pyodbc_fastexecutemany:
Fast Executemany Mode
---------------------
The Pyodbc driver has added support for a "fast executemany" mode of execution
which greatly reduces round trips for a DBAPI ``executemany()`` call when using
Microsoft ODBC drivers, for **limited size batches that fit in memory**. The
feature is enabled by setting the flag ``.fast_executemany`` on the DBAPI
cursor when an executemany call is to be used. The SQLAlchemy pyodbc SQL
Server dialect supports setting this flag automatically when the
``.fast_executemany`` flag is passed to
:func:`_sa.create_engine`; note that the ODBC driver must be the Microsoft
driver in order to use this flag::
engine = create_engine(
"mssql+pyodbc://scott:tiger@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server",
fast_executemany=True)
.. warning:: The pyodbc fast_executemany mode **buffers all rows in memory** and is
not compatible with very large batches of data. A future version of SQLAlchemy
may support this flag as a per-execution option instead.
.. versionadded:: 1.3
.. seealso::
`fast executemany <https://github.com/mkleehammer/pyodbc/wiki/Features-beyond-the-DB-API#fast_executemany>`_
- on github
.. _mssql_pyodbc_setinputsizes:
Setinputsizes Support
-----------------------
The pyodbc ``cursor.setinputsizes()`` method can be used if necessary. To
enable this hook, pass ``use_setinputsizes=True`` to :func:`_sa.create_engine`::
engine = create_engine("mssql+pyodbc://...", use_setinputsizes=True)
The behavior of the hook can then be customized, as may be necessary
particularly if fast_executemany is in use, via the
:meth:`.DialectEvents.do_setinputsizes` hook. See that method for usage
examples.
.. versionchanged:: 1.4.1 The pyodbc dialects will not use setinputsizes
unless ``use_setinputsizes=True`` is passed.
""" # noqa
import datetime
import decimal
import re
import struct
from .base import BINARY
from .base import DATETIMEOFFSET
from .base import MSDialect
from .base import MSExecutionContext
from .base import VARBINARY
from ... import exc
from ... import types as sqltypes
from ... import util
from ...connectors.pyodbc import PyODBCConnector
class _ms_numeric_pyodbc(object):
"""Turns Decimals with adjusted() < 0 or > 7 into strings.
The routines here are needed for older pyodbc versions
as well as current mxODBC versions.
"""
def bind_processor(self, dialect):
super_process = super(_ms_numeric_pyodbc, self).bind_processor(dialect)
if not dialect._need_decimal_fix:
return super_process
def process(value):
if self.asdecimal and isinstance(value, decimal.Decimal):
adjusted = value.adjusted()
if adjusted < 0:
return self._small_dec_to_string(value)
elif adjusted > 7:
return self._large_dec_to_string(value)
if super_process:
return super_process(value)
else:
return value
return process
# these routines needed for older versions of pyodbc.
# as of 2.1.8 this logic is integrated.
def _small_dec_to_string(self, value):
return "%s0.%s%s" % (
(value < 0 and "-" or ""),
"0" * (abs(value.adjusted()) - 1),
"".join([str(nint) for nint in value.as_tuple()[1]]),
)
def _large_dec_to_string(self, value):
_int = value.as_tuple()[1]
if "E" in str(value):
result = "%s%s%s" % (
(value < 0 and "-" or ""),
"".join([str(s) for s in _int]),
"0" * (value.adjusted() - (len(_int) - 1)),
)
else:
if (len(_int) - 1) > value.adjusted():
result = "%s%s.%s" % (
(value < 0 and "-" or ""),
"".join([str(s) for s in _int][0 : value.adjusted() + 1]),
"".join([str(s) for s in _int][value.adjusted() + 1 :]),
)
else:
result = "%s%s" % (
(value < 0 and "-" or ""),
"".join([str(s) for s in _int][0 : value.adjusted() + 1]),
)
return result
class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric):
pass
class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float):
pass
class _ms_binary_pyodbc(object):
"""Wraps binary values in dialect-specific Binary wrapper.
If the value is null, return a pyodbc-specific BinaryNull
object to prevent pyODBC [and FreeTDS] from defaulting binary
NULL types to SQLWCHAR and causing implicit conversion errors.
"""
def bind_processor(self, dialect):
if dialect.dbapi is None:
return None
DBAPIBinary = dialect.dbapi.Binary
def process(value):
if value is not None:
return DBAPIBinary(value)
else:
# pyodbc-specific
return dialect.dbapi.BinaryNull
return process
class _ODBCDateTimeBindProcessor(object):
"""Add bind processors to handle datetimeoffset behaviors"""
has_tz = False
def bind_processor(self, dialect):
def process(value):
if value is None:
return None
elif isinstance(value, util.string_types):
# if a string was passed directly, allow it through
return value
elif not value.tzinfo or (not self.timezone and not self.has_tz):
# for DateTime(timezone=False)
return value
else:
# for DATETIMEOFFSET or DateTime(timezone=True)
#
# Convert to string format required by T-SQL
dto_string = value.strftime("%Y-%m-%d %H:%M:%S.%f %z")
# offset needs a colon, e.g., -0700 -> -07:00
# "UTC offset in the form (+-)HHMM[SS[.ffffff]]"
# backend currently rejects seconds / fractional seconds
dto_string = re.sub(
r"([\+\-]\d{2})([\d\.]+)$", r"\1:\2", dto_string
)
return dto_string
return process
class _ODBCDateTime(_ODBCDateTimeBindProcessor, sqltypes.DateTime):
pass
class _ODBCDATETIMEOFFSET(_ODBCDateTimeBindProcessor, DATETIMEOFFSET):
has_tz = True
class _VARBINARY_pyodbc(_ms_binary_pyodbc, VARBINARY):
pass
class _BINARY_pyodbc(_ms_binary_pyodbc, BINARY):
pass
class MSExecutionContext_pyodbc(MSExecutionContext):
_embedded_scope_identity = False
def pre_exec(self):
"""where appropriate, issue "select scope_identity()" in the same
statement.
Background on why "scope_identity()" is preferable to "@@identity":
https://msdn.microsoft.com/en-us/library/ms190315.aspx
Background on why we attempt to embed "scope_identity()" into the same
statement as the INSERT:
https://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
"""
super(MSExecutionContext_pyodbc, self).pre_exec()
# don't embed the scope_identity select into an
# "INSERT .. DEFAULT VALUES"
if (
self._select_lastrowid
and self.dialect.use_scope_identity
and len(self.parameters[0])
):
self._embedded_scope_identity = True
self.statement += "; select scope_identity()"
def post_exec(self):
if self._embedded_scope_identity:
# Fetch the last inserted id from the manipulated statement
# We may have to skip over a number of result sets with
# no data (due to triggers, etc.)
while True:
try:
# fetchall() ensures the cursor is consumed
# without closing it (FreeTDS particularly)
row = self.cursor.fetchall()[0]
break
except self.dialect.dbapi.Error:
# no way around this - nextset() consumes the previous set
# so we need to just keep flipping
self.cursor.nextset()
self._lastrowid = int(row[0])
else:
super(MSExecutionContext_pyodbc, self).post_exec()
class MSDialect_pyodbc(PyODBCConnector, MSDialect):
supports_statement_cache = True
# mssql still has problems with this on Linux
supports_sane_rowcount_returning = False
execution_ctx_cls = MSExecutionContext_pyodbc
colspecs = util.update_copy(
MSDialect.colspecs,
{
sqltypes.Numeric: _MSNumeric_pyodbc,
sqltypes.Float: _MSFloat_pyodbc,
BINARY: _BINARY_pyodbc,
# support DateTime(timezone=True)
sqltypes.DateTime: _ODBCDateTime,
DATETIMEOFFSET: _ODBCDATETIMEOFFSET,
# SQL Server dialect has a VARBINARY that is just to support
# "deprecate_large_types" w/ VARBINARY(max), but also we must
# handle the usual SQL standard VARBINARY
VARBINARY: _VARBINARY_pyodbc,
sqltypes.VARBINARY: _VARBINARY_pyodbc,
sqltypes.LargeBinary: _VARBINARY_pyodbc,
},
)
def __init__(
self, description_encoding=None, fast_executemany=False, **params
):
if "description_encoding" in params:
self.description_encoding = params.pop("description_encoding")
super(MSDialect_pyodbc, self).__init__(**params)
self.use_scope_identity = (
self.use_scope_identity
and self.dbapi
and hasattr(self.dbapi.Cursor, "nextset")
)
self._need_decimal_fix = self.dbapi and self._dbapi_version() < (
2,
1,
8,
)
self.fast_executemany = fast_executemany
def _get_server_version_info(self, connection):
try:
# "Version of the instance of SQL Server, in the form
# of 'major.minor.build.revision'"
raw = connection.exec_driver_sql(
"SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)"
).scalar()
except exc.DBAPIError:
# SQL Server docs indicate this function isn't present prior to
# 2008. Before we had the VARCHAR cast above, pyodbc would also
# fail on this query.
return super(MSDialect_pyodbc, self)._get_server_version_info(
connection, allow_chars=False
)
else:
version = []
r = re.compile(r"[.\-]")
for n in r.split(raw):
try:
version.append(int(n))
except ValueError:
pass
return tuple(version)
def on_connect(self):
super_ = super(MSDialect_pyodbc, self).on_connect()
def on_connect(conn):
if super_ is not None:
super_(conn)
self._setup_timestampoffset_type(conn)
return on_connect
def _setup_timestampoffset_type(self, connection):
# output converter function for datetimeoffset
def _handle_datetimeoffset(dto_value):
tup = struct.unpack("<6hI2h", dto_value)
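            # tup: (year, month, day, hour, minute, second,
            # fraction-in-nanoseconds, tz-offset-hours, tz-offset-minutes);
            # the fraction is floor-divided by 1000 below for microseconds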
return datetime.datetime(
tup[0],
tup[1],
tup[2],
tup[3],
tup[4],
tup[5],
tup[6] // 1000,
util.timezone(
datetime.timedelta(hours=tup[7], minutes=tup[8])
),
)
odbc_SQL_SS_TIMESTAMPOFFSET = -155 # as defined in SQLNCLI.h
connection.add_output_converter(
odbc_SQL_SS_TIMESTAMPOFFSET, _handle_datetimeoffset
)
def do_executemany(self, cursor, statement, parameters, context=None):
if self.fast_executemany:
cursor.fast_executemany = True
super(MSDialect_pyodbc, self).do_executemany(
cursor, statement, parameters, context=context
)
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.Error):
code = e.args[0]
if code in {
"08S01",
"01000",
"01002",
"08003",
"08007",
"08S02",
"08001",
"HYT00",
"HY010",
"10054",
}:
return True
return super(MSDialect_pyodbc, self).is_disconnect(
e, connection, cursor
)
dialect = MSDialect_pyodbc

View File

@@ -0,0 +1,103 @@
# mysql/__init__.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from . import base # noqa
from . import cymysql # noqa
from . import mariadbconnector # noqa
from . import mysqlconnector # noqa
from . import mysqldb # noqa
from . import oursql # noqa
from . import pymysql # noqa
from . import pyodbc # noqa
from .base import BIGINT
from .base import BINARY
from .base import BIT
from .base import BLOB
from .base import BOOLEAN
from .base import CHAR
from .base import DATE
from .base import DATETIME
from .base import DECIMAL
from .base import DOUBLE
from .base import ENUM
from .base import FLOAT
from .base import INTEGER
from .base import JSON
from .base import LONGBLOB
from .base import LONGTEXT
from .base import MEDIUMBLOB
from .base import MEDIUMINT
from .base import MEDIUMTEXT
from .base import NCHAR
from .base import NUMERIC
from .base import NVARCHAR
from .base import REAL
from .base import SET
from .base import SMALLINT
from .base import TEXT
from .base import TIME
from .base import TIMESTAMP
from .base import TINYBLOB
from .base import TINYINT
from .base import TINYTEXT
from .base import VARBINARY
from .base import VARCHAR
from .base import YEAR
from .dml import Insert
from .dml import insert
from .expression import match
from ...util import compat
if compat.py3k:
from . import aiomysql # noqa
from . import asyncmy # noqa
# default dialect
base.dialect = dialect = mysqldb.dialect
__all__ = (
"BIGINT",
"BINARY",
"BIT",
"BLOB",
"BOOLEAN",
"CHAR",
"DATE",
"DATETIME",
"DECIMAL",
"DOUBLE",
"ENUM",
"DECIMAL",
"FLOAT",
"INTEGER",
"INTEGER",
"JSON",
"LONGBLOB",
"LONGTEXT",
"MEDIUMBLOB",
"MEDIUMINT",
"MEDIUMTEXT",
"NCHAR",
"NVARCHAR",
"NUMERIC",
"SET",
"SMALLINT",
"REAL",
"TEXT",
"TIME",
"TIMESTAMP",
"TINYBLOB",
"TINYINT",
"TINYTEXT",
"VARBINARY",
"VARCHAR",
"YEAR",
"dialect",
"insert",
"Insert",
"match",
)

View File

@@ -0,0 +1,317 @@
# mysql/aiomysql.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: mysql+aiomysql
:name: aiomysql
:dbapi: aiomysql
:connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...]
:url: https://github.com/aio-libs/aiomysql
.. warning:: The aiomysql dialect is not currently tested as part of
SQLAlchemy's continuous integration. As of September, 2021 the driver
appears to be unmaintained and no longer functions for Python version 3.10,
and additionally depends on a significantly outdated version of PyMySQL.
Please refer to the :ref:`asyncmy` dialect for current MySQL/MariaDB asyncio
functionality.
The aiomysql dialect is SQLAlchemy's second Python asyncio dialect.
Using a special asyncio mediation layer, the aiomysql dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.
This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4")
""" # noqa
from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
class AsyncAdapt_aiomysql_cursor:
server_side = False
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor()
# see https://github.com/aio-libs/aiomysql/issues/543
self._cursor = self.await_(cursor.__aenter__())
self._rows = []
@property
def description(self):
return self._cursor.description
@property
def rowcount(self):
return self._cursor.rowcount
@property
def arraysize(self):
return self._cursor.arraysize
@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value
@property
def lastrowid(self):
return self._cursor.lastrowid
def close(self):
# note we aren't actually closing the cursor here,
# we are just letting GC do it. to allow this to be async
# we would need the Result to change how it does "Safe close cursor".
# MySQL "cursors" don't actually have state to be "closed" besides
# exhausting rows, which we already have done for sync cursor.
# another option would be to emulate aiosqlite dialect and assign
# cursor only if we are doing server side cursor operation.
self._rows[:] = []
def execute(self, operation, parameters=None):
return self.await_(self._execute_async(operation, parameters))
def executemany(self, operation, seq_of_parameters):
return self.await_(
self._executemany_async(operation, seq_of_parameters)
)
async def _execute_async(self, operation, parameters):
async with self._adapt_connection._execute_mutex:
if parameters is None:
result = await self._cursor.execute(operation)
else:
result = await self._cursor.execute(operation, parameters)
if not self.server_side:
# aiomysql has a "fake" async result, so we have to pull it out
# of that here since our default result is not async.
# we could just as easily grab "_rows" here and be done with it
# but this is safer.
self._rows = list(await self._cursor.fetchall())
return result
async def _executemany_async(self, operation, seq_of_parameters):
async with self._adapt_connection._execute_mutex:
return await self._cursor.executemany(operation, seq_of_parameters)
def setinputsizes(self, *inputsizes):
pass
def __iter__(self):
while self._rows:
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval
class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
__slots__ = ()
server_side = True
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor(
adapt_connection.dbapi.aiomysql.SSCursor
)
self._cursor = self.await_(cursor.__aenter__())
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=None):
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
class AsyncAdapt_aiomysql_connection(AdaptedConnection):
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_connection", "_execute_mutex")
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = asyncio.Lock()
def ping(self, reconnect):
return self.await_(self._connection.ping(reconnect))
def character_set_name(self):
return self._connection.character_set_name()
def autocommit(self, value):
self.await_(self._connection.autocommit(value))
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_aiomysql_ss_cursor(self)
else:
return AsyncAdapt_aiomysql_cursor(self)
def rollback(self):
self.await_(self._connection.rollback())
def commit(self):
self.await_(self._connection.commit())
def close(self):
# it's not awaitable.
self._connection.close()
class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
__slots__ = ()
await_ = staticmethod(await_fallback)
class AsyncAdapt_aiomysql_dbapi:
def __init__(self, aiomysql, pymysql):
self.aiomysql = aiomysql
self.pymysql = pymysql
self.paramstyle = "format"
self._init_dbapi_attributes()
def _init_dbapi_attributes(self):
for name in (
"Warning",
"Error",
"InterfaceError",
"DataError",
"DatabaseError",
"OperationalError",
"InterfaceError",
"IntegrityError",
"ProgrammingError",
"InternalError",
"NotSupportedError",
):
setattr(self, name, getattr(self.aiomysql, name))
for name in (
"NUMBER",
"STRING",
"DATETIME",
"BINARY",
"TIMESTAMP",
"Binary",
):
setattr(self, name, getattr(self.pymysql, name))
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
if util.asbool(async_fallback):
return AsyncAdaptFallback_aiomysql_connection(
self,
await_fallback(self.aiomysql.connect(*arg, **kw)),
)
else:
return AsyncAdapt_aiomysql_connection(
self,
await_only(self.aiomysql.connect(*arg, **kw)),
)
class MySQLDialect_aiomysql(MySQLDialect_pymysql):
driver = "aiomysql"
supports_statement_cache = True
supports_server_side_cursors = True
_sscursor = AsyncAdapt_aiomysql_ss_cursor
is_async = True
@classmethod
def dbapi(cls):
return AsyncAdapt_aiomysql_dbapi(
__import__("aiomysql"), __import__("pymysql")
)
@classmethod
def get_pool_class(cls, url):
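        # async_fallback mode adapts the async driver for use with a
        # regular (sync) Engine by awaiting via a blocking fallback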
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
return pool.FallbackAsyncAdaptedQueuePool
else:
return pool.AsyncAdaptedQueuePool
def create_connect_args(self, url):
return super(MySQLDialect_aiomysql, self).create_connect_args(
url, _translate_args=dict(username="user", database="db")
)
def is_disconnect(self, e, connection, cursor):
if super(MySQLDialect_aiomysql, self).is_disconnect(
e, connection, cursor
):
return True
else:
str_e = str(e).lower()
return "not connected" in str_e
def _found_rows_client_flag(self):
from pymysql.constants import CLIENT
return CLIENT.FOUND_ROWS
def get_driver_connection(self, connection):
return connection._connection
dialect = MySQLDialect_aiomysql

View File

@@ -0,0 +1,328 @@
# mysql/asyncmy.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: mysql+asyncmy
:name: asyncmy
:dbapi: asyncmy
:connectstring: mysql+asyncmy://user:password@host:port/dbname[?key=value&key=value...]
:url: https://github.com/long2ice/asyncmy
.. note:: The asyncmy dialect as of September, 2021 was added to provide
MySQL/MariaDB asyncio compatibility given that the :ref:`aiomysql` database
driver has become unmaintained; however, asyncmy is itself very new.
Using a special asyncio mediation layer, the asyncmy dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.
This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")
""" # noqa
from .pymysql import MySQLDialect_pymysql
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...util.concurrency import asynccontextmanager
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
class AsyncAdapt_asyncmy_cursor:
server_side = False
__slots__ = (
"_adapt_connection",
"_connection",
"await_",
"_cursor",
"_rows",
)
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor()
self._cursor = self.await_(cursor.__aenter__())
self._rows = []
@property
def description(self):
return self._cursor.description
@property
def rowcount(self):
return self._cursor.rowcount
@property
def arraysize(self):
return self._cursor.arraysize
@arraysize.setter
def arraysize(self, value):
self._cursor.arraysize = value
@property
def lastrowid(self):
return self._cursor.lastrowid
def close(self):
# note we aren't actually closing the cursor here,
# we are just letting GC do it. to allow this to be async
# we would need the Result to change how it does "Safe close cursor".
# MySQL "cursors" don't actually have state to be "closed" besides
# exhausting rows, which we already have done for sync cursor.
# another option would be to emulate aiosqlite dialect and assign
# cursor only if we are doing server side cursor operation.
self._rows[:] = []
def execute(self, operation, parameters=None):
return self.await_(self._execute_async(operation, parameters))
def executemany(self, operation, seq_of_parameters):
return self.await_(
self._executemany_async(operation, seq_of_parameters)
)
async def _execute_async(self, operation, parameters):
async with self._adapt_connection._mutex_and_adapt_errors():
if parameters is None:
result = await self._cursor.execute(operation)
else:
result = await self._cursor.execute(operation, parameters)
if not self.server_side:
# asyncmy has a "fake" async result, so we have to pull it out
# of that here since our default result is not async.
# we could just as easily grab "_rows" here and be done with it
# but this is safer.
self._rows = list(await self._cursor.fetchall())
return result
async def _executemany_async(self, operation, seq_of_parameters):
async with self._adapt_connection._mutex_and_adapt_errors():
return await self._cursor.executemany(operation, seq_of_parameters)
def setinputsizes(self, *inputsizes):
pass
def __iter__(self):
while self._rows:
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval
class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
__slots__ = ()
server_side = True
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
cursor = self._connection.cursor(
adapt_connection.dbapi.asyncmy.cursors.SSCursor
)
self._cursor = self.await_(cursor.__aenter__())
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=None):
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
class AsyncAdapt_asyncmy_connection(AdaptedConnection):
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_connection", "_execute_mutex")
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
self._execute_mutex = asyncio.Lock()
@asynccontextmanager
async def _mutex_and_adapt_errors(self):
async with self._execute_mutex:
try:
yield
except AttributeError:
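                # asyncmy raises AttributeError when the underlying
                # connection is broken; re-raise as a DBAPI error whose
                # message is recognized by is_disconnect()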
raise self.dbapi.InternalError(
"network operation failed due to asyncmy attribute error"
)
def ping(self, reconnect):
assert not reconnect
return self.await_(self._do_ping())
async def _do_ping(self):
async with self._mutex_and_adapt_errors():
return await self._connection.ping(False)
def character_set_name(self):
return self._connection.character_set_name()
def autocommit(self, value):
self.await_(self._connection.autocommit(value))
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_asyncmy_ss_cursor(self)
else:
return AsyncAdapt_asyncmy_cursor(self)
def rollback(self):
self.await_(self._connection.rollback())
def commit(self):
self.await_(self._connection.commit())
def close(self):
# it's not awaitable.
self._connection.close()
class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
__slots__ = ()
await_ = staticmethod(await_fallback)
def _Binary(x):
"""Return x as a binary type."""
return bytes(x)
class AsyncAdapt_asyncmy_dbapi:
def __init__(self, asyncmy):
self.asyncmy = asyncmy
self.paramstyle = "format"
self._init_dbapi_attributes()
def _init_dbapi_attributes(self):
for name in (
"Warning",
"Error",
"InterfaceError",
"DataError",
"DatabaseError",
"OperationalError",
"InterfaceError",
"IntegrityError",
"ProgrammingError",
"InternalError",
"NotSupportedError",
):
setattr(self, name, getattr(self.asyncmy.errors, name))
STRING = util.symbol("STRING")
NUMBER = util.symbol("NUMBER")
BINARY = util.symbol("BINARY")
DATETIME = util.symbol("DATETIME")
TIMESTAMP = util.symbol("TIMESTAMP")
Binary = staticmethod(_Binary)
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
if util.asbool(async_fallback):
return AsyncAdaptFallback_asyncmy_connection(
self,
await_fallback(self.asyncmy.connect(*arg, **kw)),
)
else:
return AsyncAdapt_asyncmy_connection(
self,
await_only(self.asyncmy.connect(*arg, **kw)),
)
class MySQLDialect_asyncmy(MySQLDialect_pymysql):
driver = "asyncmy"
supports_statement_cache = True
supports_server_side_cursors = True
_sscursor = AsyncAdapt_asyncmy_ss_cursor
is_async = True
@classmethod
def dbapi(cls):
return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
@classmethod
def get_pool_class(cls, url):
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
return pool.FallbackAsyncAdaptedQueuePool
else:
return pool.AsyncAdaptedQueuePool
def create_connect_args(self, url):
return super(MySQLDialect_asyncmy, self).create_connect_args(
url, _translate_args=dict(username="user", database="db")
)
def is_disconnect(self, e, connection, cursor):
if super(MySQLDialect_asyncmy, self).is_disconnect(
e, connection, cursor
):
return True
else:
str_e = str(e).lower()
return (
"not connected" in str_e or "network operation failed" in str_e
)
def _found_rows_client_flag(self):
from asyncmy.constants import CLIENT
return CLIENT.FOUND_ROWS
def get_driver_connection(self, connection):
return connection._connection
dialect = MySQLDialect_asyncmy

File diff suppressed because it is too large

View File

@@ -0,0 +1,82 @@
# mysql/cymysql.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: mysql+cymysql
:name: CyMySQL
:dbapi: cymysql
:connectstring: mysql+cymysql://<username>:<password>@<host>/<dbname>[?<options>]
:url: https://github.com/nakagami/CyMySQL
.. note::
The CyMySQL dialect is **not tested as part of SQLAlchemy's continuous
integration** and may have unresolved issues. The recommended MySQL
dialects are mysqlclient and PyMySQL.
""" # noqa
from .base import BIT
from .base import MySQLDialect
from .mysqldb import MySQLDialect_mysqldb
from ... import util
class _cymysqlBIT(BIT):
def result_processor(self, dialect, coltype):
"""Convert MySQL's 64 bit, variable length binary string to a long."""
def process(value):
if value is not None:
v = 0
for i in util.iterbytes(value):
v = v << 8 | i
return v
return value
return process
class MySQLDialect_cymysql(MySQLDialect_mysqldb):
driver = "cymysql"
supports_statement_cache = True
description_encoding = None
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
supports_unicode_statements = True
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
@classmethod
def dbapi(cls):
return __import__("cymysql")
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.OperationalError):
return self._extract_error_code(e) in (
2006,
2013,
2014,
2045,
2055,
)
elif isinstance(e, self.dbapi.InterfaceError):
# if underlying connection is closed,
# this is the error you get
return True
else:
return False
dialect = MySQLDialect_cymysql

View File

@@ -0,0 +1,175 @@
from ... import exc
from ... import util
from ...sql.base import _exclusive_against
from ...sql.base import _generative
from ...sql.base import ColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.expression import alias
from ...util.langhelpers import public_factory
__all__ = ("Insert", "insert")
class Insert(StandardInsert):
"""MySQL-specific implementation of INSERT.
Adds methods for MySQL-specific syntaxes such as ON DUPLICATE KEY UPDATE.
The :class:`~.mysql.Insert` object is created using the
:func:`sqlalchemy.dialects.mysql.insert` function.
.. versionadded:: 1.2
"""
stringify_dialect = "mysql"
inherit_cache = False
@property
def inserted(self):
"""Provide the "inserted" namespace for an ON DUPLICATE KEY UPDATE
statement.
MySQL's ON DUPLICATE KEY UPDATE clause allows reference to the row
that would be inserted, via a special function called ``VALUES()``.
This attribute provides all columns in this row to be referenceable
such that they will render within a ``VALUES()`` function inside the
ON DUPLICATE KEY UPDATE clause. The attribute is named ``.inserted``
so as not to conflict with the existing
:meth:`_expression.Insert.values` method.
.. tip:: The :attr:`_mysql.Insert.inserted` attribute is an instance
of :class:`_expression.ColumnCollection`, which provides an
interface the same as that of the :attr:`_schema.Table.c`
collection described at :ref:`metadata_tables_and_columns`.
With this collection, ordinary names are accessible like attributes
(e.g. ``stmt.inserted.some_column``), but special names and
dictionary method names should be accessed using indexed access,
such as ``stmt.inserted["column name"]`` or
``stmt.inserted["values"]``. See the docstring for
:class:`_expression.ColumnCollection` for further examples.
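For example (table and column names here are illustrative)::
    stmt = insert(my_table).values(id=1, data="d1")
    stmt = stmt.on_duplicate_key_update(data=stmt.inserted.data)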
.. seealso::
:ref:`mysql_insert_on_duplicate_key_update` - example of how
to use :attr:`_expression.Insert.inserted`
"""
return self.inserted_alias.columns
@util.memoized_property
def inserted_alias(self):
return alias(self.table, name="inserted")
@_generative
@_exclusive_against(
"_post_values_clause",
msgs={
"_post_values_clause": "This Insert construct already "
"has an ON DUPLICATE KEY clause present"
},
)
def on_duplicate_key_update(self, *args, **kw):
r"""
Specifies the ON DUPLICATE KEY UPDATE clause.
:param \**kw: Column keys linked to UPDATE values. The
values may be any SQL expression or supported literal Python
values.
.. warning:: This dictionary does **not** take into account
Python-specified default UPDATE values or generation functions,
e.g. those specified using :paramref:`_schema.Column.onupdate`.
These values will not be exercised for an ON DUPLICATE KEY UPDATE
style of UPDATE, unless values are manually specified here.
:param \*args: As an alternative to passing key/value parameters,
a dictionary or list of 2-tuples can be passed as a single positional
argument.
Passing a single dictionary is equivalent to the keyword argument
form::
insert().on_duplicate_key_update({"name": "some name"})
Passing a list of 2-tuples indicates that the parameter assignments
in the UPDATE clause should be ordered as sent, in a manner similar
to that described for the :class:`_expression.Update`
construct overall
in :ref:`tutorial_parameter_ordered_updates`::
insert().on_duplicate_key_update(
[("name", "some name"), ("value", "some value")])
.. versionchanged:: 1.3 parameters can be specified as a dictionary
or list of 2-tuples; the latter form provides for parameter
ordering.
.. versionadded:: 1.2
.. seealso::
:ref:`mysql_insert_on_duplicate_key_update`
"""
if args and kw:
raise exc.ArgumentError(
"Can't pass kwargs and positional arguments simultaneously"
)
if args:
if len(args) > 1:
raise exc.ArgumentError(
"Only a single dictionary or list of tuples "
"is accepted positionally."
)
values = args[0]
else:
values = kw
inserted_alias = getattr(self, "inserted_alias", None)
self._post_values_clause = OnDuplicateClause(inserted_alias, values)
insert = public_factory(
Insert, ".dialects.mysql.insert", ".dialects.mysql.Insert"
)
class OnDuplicateClause(ClauseElement):
__visit_name__ = "on_duplicate_key_update"
_parameter_ordering = None
stringify_dialect = "mysql"
def __init__(self, inserted_alias, update):
self.inserted_alias = inserted_alias
# auto-detect that parameters should be ordered. This is copied from
        # Update._process_colparams(), however we don't look for a special flag
# in this case since we are not disambiguating from other use cases as
# we are in Update.values().
if isinstance(update, list) and (
update and isinstance(update[0], tuple)
):
self._parameter_ordering = [key for key, value in update]
update = dict(update)
if isinstance(update, dict):
if not update:
raise ValueError(
"update parameter dictionary must not be empty"
)
elif isinstance(update, ColumnCollection):
update = dict(update)
else:
raise ValueError(
"update parameter must be a non-empty dictionary "
"or a ColumnCollection such as the `.c.` collection "
"of a Table object"
)
self.update = update

View File

@@ -0,0 +1,263 @@
# mysql/enumerated.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
import re
from .types import _StringType
from ... import exc
from ... import sql
from ... import util
from ...sql import sqltypes
from ...sql.base import NO_ARG
class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
"""MySQL ENUM type."""
__visit_name__ = "ENUM"
native_enum = True
def __init__(self, *enums, **kw):
"""Construct an ENUM.
E.g.::
Column('myenum', ENUM("foo", "bar", "baz"))
:param enums: The range of valid values for this ENUM. Values in
enums are not quoted, they will be escaped and surrounded by single
quotes when generating the schema. This object may also be a
PEP-435-compliant enumerated type.
.. versionadded:: 1.1 Added support for PEP-435-compliant enumerated
types.
:param strict: This flag has no effect.
.. versionchanged:: The MySQL ENUM type as well as the base Enum
type now validates all Python data values.
:param charset: Optional, a column-level character set for this string
value. Takes precedence over the 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence over the 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
:param quoting: Not used. A warning will be raised if provided.
"""
if kw.pop("quoting", NO_ARG) is not NO_ARG:
util.warn_deprecated_20(
"The 'quoting' parameter to :class:`.mysql.ENUM` is deprecated"
" and will be removed in a future release. "
"This parameter now has no effect."
)
kw.pop("strict", None)
self._enum_init(enums, kw)
_StringType.__init__(self, length=self.length, **kw)
@classmethod
def adapt_emulated_to_native(cls, impl, **kw):
"""Produce a MySQL native :class:`.mysql.ENUM` from plain
:class:`.Enum`.
"""
kw.setdefault("validate_strings", impl.validate_strings)
kw.setdefault("values_callable", impl.values_callable)
kw.setdefault("omit_aliases", impl._omit_aliases)
return cls(**kw)
def _object_value_for_elem(self, elem):
# mysql sends back a blank string for any value that
# was persisted that was not in the enums; that is, it does no
# validation on the incoming data, it "truncates" it to be
# the blank string. Return it straight.
if elem == "":
return elem
else:
return super(ENUM, self)._object_value_for_elem(elem)
def __repr__(self):
return util.generic_repr(
self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
)
class SET(_StringType):
"""MySQL SET type."""
__visit_name__ = "SET"
def __init__(self, *values, **kw):
"""Construct a SET.
E.g.::
Column('myset', SET("foo", "bar", "baz"))
The list of potential values is required in the case that this
set will be used to generate DDL for a table, or if the
:paramref:`.SET.retrieve_as_bitwise` flag is set to True.
:param values: The range of valid values for this SET. The values
are not quoted, they will be escaped and surrounded by single
quotes when generating the schema.
:param convert_unicode: Same flag as that of
:paramref:`.String.convert_unicode`.
:param collation: same as that of :paramref:`.String.collation`
:param charset: same as that of :paramref:`.VARCHAR.charset`.
:param ascii: same as that of :paramref:`.VARCHAR.ascii`.
:param unicode: same as that of :paramref:`.VARCHAR.unicode`.
:param binary: same as that of :paramref:`.VARCHAR.binary`.
:param retrieve_as_bitwise: if True, the data for the set type will be
persisted and selected using an integer value, where a set is coerced
into a bitwise mask for persistence. MySQL allows this mode which
has the advantage of being able to store values unambiguously,
such as the blank string ``''``. The datatype will appear
as the expression ``col + 0`` in a SELECT statement, so that the
value is coerced into an integer value in result sets.
This flag is required if one wishes
to persist a set that can store the blank string ``''`` as a value.
.. warning::
When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is
essential that the list of set values is expressed in the
**exact same order** as exists on the MySQL database.
.. versionadded:: 1.0.0
:param quoting: Not used. A warning will be raised if passed.
"""
if kw.pop("quoting", NO_ARG) is not NO_ARG:
util.warn_deprecated_20(
"The 'quoting' parameter to :class:`.mysql.SET` is deprecated"
" and will be removed in a future release. "
"This parameter now has no effect."
)
self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False)
self.values = tuple(values)
if not self.retrieve_as_bitwise and "" in values:
raise exc.ArgumentError(
"Can't use the blank value '' in a SET without "
"setting retrieve_as_bitwise=True"
)
if self.retrieve_as_bitwise:
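            # a single dict maps in both directions: value -> bit flag
            # (used by bind_processor) and bit flag -> value (used by
            # result_processor via util.map_bits)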
self._bitmap = dict(
(value, 2 ** idx) for idx, value in enumerate(self.values)
)
self._bitmap.update(
(2 ** idx, value) for idx, value in enumerate(self.values)
)
length = max([len(v) for v in values] + [0])
kw.setdefault("length", length)
super(SET, self).__init__(**kw)
def column_expression(self, colexpr):
if self.retrieve_as_bitwise:
return sql.type_coerce(
sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
)
else:
return colexpr
def result_processor(self, dialect, coltype):
if self.retrieve_as_bitwise:
def process(value):
if value is not None:
value = int(value)
return set(util.map_bits(self._bitmap.__getitem__, value))
else:
return None
else:
super_convert = super(SET, self).result_processor(dialect, coltype)
def process(value):
if isinstance(value, util.string_types):
# MySQLdb returns a string, let's parse
if super_convert:
value = super_convert(value)
return set(re.findall(r"[^,]+", value))
else:
# mysql-connector-python does a naive
# split(",") which throws in an empty string
if value is not None:
value.discard("")
return value
return process
def bind_processor(self, dialect):
super_convert = super(SET, self).bind_processor(dialect)
if self.retrieve_as_bitwise:
def process(value):
if value is None:
return None
elif isinstance(value, util.int_types + util.string_types):
if super_convert:
return super_convert(value)
else:
return value
else:
int_value = 0
for v in value:
int_value |= self._bitmap[v]
return int_value
else:
def process(value):
# accept strings and int (actually bitflag) values directly
if value is not None and not isinstance(
value, util.int_types + util.string_types
):
value = ",".join(value)
if super_convert:
return super_convert(value)
else:
return value
return process
def adapt(self, impltype, **kw):
kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
return util.constructor_copy(self, impltype, *self.values, **kw)
def __repr__(self):
return util.generic_repr(
self,
to_inspect=[SET, _StringType],
additional_kw=[
("retrieve_as_bitwise", False),
],
)

View File

@@ -0,0 +1,130 @@
from ... import exc
from ... import util
from ...sql import coercions
from ...sql import elements
from ...sql import operators
from ...sql import roles
from ...sql.base import _generative
from ...sql.base import Generative
class match(Generative, elements.BinaryExpression):
"""Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
E.g.::
from sqlalchemy import desc
from sqlalchemy.dialects.mysql import match
match_expr = match(
users_table.c.firstname,
users_table.c.lastname,
against="Firstname Lastname",
)
stmt = (
select(users_table)
.where(match_expr.in_boolean_mode())
.order_by(desc(match_expr))
)
Would produce SQL resembling::
SELECT id, firstname, lastname
FROM user
WHERE MATCH(firstname, lastname) AGAINST (:param_1 IN BOOLEAN MODE)
ORDER BY MATCH(firstname, lastname) AGAINST (:param_2) DESC
The :func:`_mysql.match` function is a standalone version of the
:meth:`_sql.ColumnElement.match` method available on all
SQL expressions; it renders the same MATCH construct as
:meth:`_expression.ColumnElement.match`, but allows passing multiple columns
:param cols: column expressions to match against
:param against: expression to be compared against
:param in_boolean_mode: boolean, set "boolean mode" to true
:param in_natural_language_mode: boolean, set "natural language" to true
:param with_query_expansion: boolean, set "query expansion" to true
.. versionadded:: 1.4.19
.. seealso::
:meth:`_expression.ColumnElement.match`
"""
__visit_name__ = "mysql_match"
inherit_cache = True
def __init__(self, *cols, **kw):
if not cols:
raise exc.ArgumentError("columns are required")
against = kw.pop("against", None)
if against is None:
raise exc.ArgumentError("against is required")
against = coercions.expect(
roles.ExpressionElementRole,
against,
)
left = elements.BooleanClauseList._construct_raw(
operators.comma_op,
clauses=cols,
)
left.group = False
flags = util.immutabledict(
{
"mysql_boolean_mode": kw.pop("in_boolean_mode", False),
"mysql_natural_language": kw.pop(
"in_natural_language_mode", False
),
"mysql_query_expansion": kw.pop("with_query_expansion", False),
}
)
if kw:
raise exc.ArgumentError("unknown arguments: %s" % (", ".join(kw)))
super(match, self).__init__(
left, against, operators.match_op, modifiers=flags
)
@_generative
def in_boolean_mode(self):
"""Apply the "IN BOOLEAN MODE" modifier to the MATCH expression.
:return: a new :class:`_mysql.match` instance with modifications
applied.
"""
self.modifiers = self.modifiers.union({"mysql_boolean_mode": True})
@_generative
def in_natural_language_mode(self):
"""Apply the "IN NATURAL LANGUAGE MODE" modifier to the MATCH
expression.
:return: a new :class:`_mysql.match` instance with modifications
applied.
"""
self.modifiers = self.modifiers.union({"mysql_natural_language": True})
@_generative
def with_query_expansion(self):
"""Apply the "WITH QUERY EXPANSION" modifier to the MATCH expression.
:return: a new :class:`_mysql.match` instance with modifications
applied.
"""
self.modifiers = self.modifiers.union({"mysql_query_expansion": True})
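if __name__ == "__main__":
    # Minimal sketch of the generative modifiers above. Assumption: a
    # hypothetical "user" table; no database connection is needed just to
    # render the SQL.
    from sqlalchemy import Column, MetaData, String, Table, select
    from sqlalchemy.dialects import mysql

    users_table = Table(
        "user",
        MetaData(),
        Column("firstname", String(50)),
        Column("lastname", String(50)),
    )
    expr = match(
        users_table.c.firstname,
        users_table.c.lastname,
        against="Firstname Lastname",
    )
    # each modifier returns a *new* match object; the original is unchanged
    stmt = select(users_table).where(expr.in_boolean_mode())
    print(stmt.compile(dialect=mysql.dialect()))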

View File

@@ -0,0 +1,84 @@
# mysql/json.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import absolute_import
from ... import types as sqltypes
class JSON(sqltypes.JSON):
"""MySQL JSON type.
MySQL supports JSON as of version 5.7.
MariaDB supports JSON (as an alias for LONGTEXT) as of version 10.2.
:class:`_mysql.JSON` is used automatically whenever the base
:class:`_types.JSON` datatype is used against a MySQL or MariaDB backend.
.. seealso::
:class:`_types.JSON` - main documentation for the generic
cross-platform JSON datatype.
The :class:`.mysql.JSON` type supports persistence of JSON values
as well as the core index operations provided by :class:`_types.JSON`
datatype, by adapting the operations to render the ``JSON_EXTRACT``
function at the database level.
.. versionadded:: 1.1
"""
pass
class _FormatTypeMixin(object):
def _format_value(self, value):
raise NotImplementedError()
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value):
if isinstance(value, int):
value = "$[%s]" % value
else:
value = '$."%s"' % value
return value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
return "$%s" % (
"".join(
[
"[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
for elem in value
]
)
)
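if __name__ == "__main__":
    # Quick, database-free look at the JSON path strings these types format
    # for use with the JSON_EXTRACT function.
    print(JSONIndexType()._format_value(2))             # $[2]
    print(JSONIndexType()._format_value("name"))        # $."name"
    print(JSONPathType()._format_value(["a", 0, "b"]))  # $."a"[0]."b"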

View File

@@ -0,0 +1,25 @@
from .base import MariaDBIdentifierPreparer
from .base import MySQLDialect
class MariaDBDialect(MySQLDialect):
is_mariadb = True
supports_statement_cache = True
name = "mariadb"
preparer = MariaDBIdentifierPreparer
def loader(driver):
driver_mod = __import__(
"sqlalchemy.dialects.mysql.%s" % driver
).dialects.mysql
driver_cls = getattr(driver_mod, driver).dialect
return type(
"MariaDBDialect_%s" % driver,
(
MariaDBDialect,
driver_cls,
),
{"supports_statement_cache": True},
)
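if __name__ == "__main__":
    # Sketch: loader() composes a MariaDB-flavored dialect class on top of a
    # concrete driver dialect. Assumption: the PyMySQL driver is installed.
    cls = loader("pymysql")
    print(cls.__name__)                                  # MariaDBDialect_pymysql
    print(cls.is_mariadb, cls.supports_statement_cache)  # True True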

View File

@@ -0,0 +1,240 @@
# mysql/mariadbconnector.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+mariadbconnector
:name: MariaDB Connector/Python
:dbapi: mariadb
:connectstring: mariadb+mariadbconnector://<user>:<password>@<host>[:<port>]/<dbname>
:url: https://pypi.org/project/mariadb/
Driver Status
-------------
MariaDB Connector/Python enables Python programs to access MariaDB and MySQL
databases using an API which is compliant with the Python DB API 2.0 (PEP-249).
It is written in C and uses the MariaDB Connector/C client library for
client-server communication.
Note that the default driver for a ``mariadb://`` connection URI continues to
be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
.. _mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
""" # noqa
import re
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from ... import sql
from ... import util
mariadb_cpy_minimum_version = (1, 0, 1)
class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
_lastrowid = None
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(buffered=False)
def create_default_cursor(self):
return self._dbapi_connection.cursor(buffered=True)
def post_exec(self):
if self.isinsert and self.compiled.postfetch_lastrowid:
self._lastrowid = self.cursor.lastrowid
def get_lastrowid(self):
return self._lastrowid
class MySQLCompiler_mariadbconnector(MySQLCompiler):
pass
class MySQLDialect_mariadbconnector(MySQLDialect):
driver = "mariadbconnector"
supports_statement_cache = True
# set this to True at the module level to prevent the driver from running
# against a backend that the server detects as MySQL. currently this appears to
# be unnecessary as MariaDB client libraries have always worked against
# MySQL databases. However, if this changes at some point, this can be
# adjusted, but PLEASE ADD A TEST in test/dialect/mysql/test_dialect.py if
# this change is made at some point to ensure the correct exception
# is raised at the correct point when running the driver against
# a MySQL backend.
# is_mariadb = True
supports_unicode_statements = True
encoding = "utf8mb4"
convert_unicode = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = "qmark"
execution_ctx_cls = MySQLExecutionContext_mariadbconnector
statement_compiler = MySQLCompiler_mariadbconnector
supports_server_side_cursors = True
@util.memoized_property
def _dbapi_version(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
return tuple(
[
int(x)
for x in re.findall(
r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
)
]
)
else:
return (99, 99, 99)
def __init__(self, **kwargs):
super(MySQLDialect_mariadbconnector, self).__init__(**kwargs)
self.paramstyle = "qmark"
if self.dbapi is not None:
if self._dbapi_version < mariadb_cpy_minimum_version:
raise NotImplementedError(
"The minimum required version for MariaDB "
"Connector/Python is %s"
% ".".join(str(x) for x in mariadb_cpy_minimum_version)
)
@classmethod
def dbapi(cls):
return __import__("mariadb")
def is_disconnect(self, e, connection, cursor):
if super(MySQLDialect_mariadbconnector, self).is_disconnect(
e, connection, cursor
):
return True
elif isinstance(e, self.dbapi.Error):
str_e = str(e).lower()
return "not connected" in str_e or "isn't valid" in str_e
else:
return False
def create_connect_args(self, url):
opts = url.translate_connect_args()
int_params = [
"connect_timeout",
"read_timeout",
"write_timeout",
"client_flag",
"port",
"pool_size",
]
bool_params = [
"local_infile",
"ssl_verify_cert",
"ssl",
"pool_reset_connection",
]
for key in int_params:
util.coerce_kw_type(opts, key, int)
for key in bool_params:
util.coerce_kw_type(opts, key, bool)
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
client_flag = opts.get("client_flag", 0)
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
self.dbapi.__name__ + ".constants.CLIENT"
).constants.CLIENT
client_flag |= CLIENT_FLAGS.FOUND_ROWS
except (AttributeError, ImportError):
self.supports_sane_rowcount = False
opts["client_flag"] = client_flag
return [[], opts]
def _extract_error_code(self, exception):
try:
rc = exception.errno
except AttributeError:
rc = -1
return rc
def _detect_charset(self, connection):
return "utf8mb4"
_isolation_lookup = set(
[
"SERIALIZABLE",
"READ UNCOMMITTED",
"READ COMMITTED",
"REPEATABLE READ",
"AUTOCOMMIT",
]
)
def _set_isolation_level(self, connection, level):
if level == "AUTOCOMMIT":
connection.autocommit = True
else:
connection.autocommit = False
super(MySQLDialect_mariadbconnector, self)._set_isolation_level(
connection, level
)
def do_begin_twophase(self, connection, xid):
connection.execute(
sql.text("XA BEGIN :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
def do_prepare_twophase(self, connection, xid):
connection.execute(
sql.text("XA END :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
connection.execute(
sql.text("XA PREPARE :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
def do_rollback_twophase(
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
connection.execute(
sql.text("XA END :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
connection.execute(
sql.text("XA ROLLBACK :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
def do_commit_twophase(
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
connection.execute(
sql.text("XA COMMIT :xid").bindparams(
sql.bindparam("xid", xid, literal_execute=True)
)
)
dialect = MySQLDialect_mariadbconnector
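if __name__ == "__main__":
    # Minimal connection sketch. Assumptions: the "mariadb" DBAPI package is
    # installed and a server is reachable at this hypothetical URL.
    from sqlalchemy import create_engine, text

    engine = create_engine(
        "mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test"
    )
    with engine.connect() as conn:
        print(conn.execute(text("SELECT VERSION()")).scalar())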

View File

@@ -0,0 +1,240 @@
# mysql/mysqlconnector.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: mysql+mysqlconnector
:name: MySQL Connector/Python
:dbapi: myconnpy
:connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
:url: https://pypi.org/project/mysql-connector-python/
.. note::
The MySQL Connector/Python DBAPI has had many issues since its release,
some of which may remain unresolved, and the mysqlconnector dialect is
**not tested as part of SQLAlchemy's continuous integration**.
The recommended MySQL dialects are mysqlclient and PyMySQL.
""" # noqa
import re
from .base import BIT
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLIdentifierPreparer
from ... import processors
from ... import util
class MySQLCompiler_mysqlconnector(MySQLCompiler):
def visit_mod_binary(self, binary, operator, **kw):
if self.dialect._mysqlconnector_double_percents:
return (
self.process(binary.left, **kw)
+ " %% "
+ self.process(binary.right, **kw)
)
else:
return (
self.process(binary.left, **kw)
+ " % "
+ self.process(binary.right, **kw)
)
def post_process_text(self, text):
if self.dialect._mysqlconnector_double_percents:
return text.replace("%", "%%")
else:
return text
def escape_literal_column(self, text):
if self.dialect._mysqlconnector_double_percents:
return text.replace("%", "%%")
else:
return text
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
@property
def _double_percents(self):
return self.dialect._mysqlconnector_double_percents
@_double_percents.setter
def _double_percents(self, value):
pass
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
if self.dialect._mysqlconnector_double_percents:
return value.replace("%", "%%")
else:
return value
class _myconnpyBIT(BIT):
def result_processor(self, dialect, coltype):
"""MySQL-connector already converts mysql bits, so."""
return None
class MySQLDialect_mysqlconnector(MySQLDialect):
driver = "mysqlconnector"
supports_statement_cache = True
supports_unicode_binds = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = "format"
statement_compiler = MySQLCompiler_mysqlconnector
preparer = MySQLIdentifierPreparer_mysqlconnector
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
def __init__(self, *arg, **kw):
super(MySQLDialect_mysqlconnector, self).__init__(*arg, **kw)
# hack description encoding since mysqlconnector randomly
# returns bytes or not
self._description_decoder = (
processors.to_conditional_unicode_processor_factory
)(self.description_encoding)
def _check_unicode_description(self, connection):
# hack description encoding since mysqlconnector randomly
# returns bytes or not
return False
@property
def description_encoding(self):
# total guess
return "latin-1"
@util.memoized_property
def supports_unicode_statements(self):
return util.py3k or self._mysqlconnector_version_info > (2, 0)
@classmethod
def dbapi(cls):
from mysql import connector
return connector
def do_ping(self, dbapi_connection):
try:
dbapi_connection.ping(False)
except self.dbapi.Error as err:
if self.is_disconnect(err, dbapi_connection, None):
return False
else:
raise
else:
return True
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
opts.update(url.query)
util.coerce_kw_type(opts, "allow_local_infile", bool)
util.coerce_kw_type(opts, "autocommit", bool)
util.coerce_kw_type(opts, "buffered", bool)
util.coerce_kw_type(opts, "compress", bool)
util.coerce_kw_type(opts, "connection_timeout", int)
util.coerce_kw_type(opts, "connect_timeout", int)
util.coerce_kw_type(opts, "consume_results", bool)
util.coerce_kw_type(opts, "force_ipv6", bool)
util.coerce_kw_type(opts, "get_warnings", bool)
util.coerce_kw_type(opts, "pool_reset_session", bool)
util.coerce_kw_type(opts, "pool_size", int)
util.coerce_kw_type(opts, "raise_on_warnings", bool)
util.coerce_kw_type(opts, "raw", bool)
util.coerce_kw_type(opts, "ssl_verify_cert", bool)
util.coerce_kw_type(opts, "use_pure", bool)
util.coerce_kw_type(opts, "use_unicode", bool)
# unfortunately, MySQL/connector python refuses to release a
# cursor without reading fully, so non-buffered isn't an option
opts.setdefault("buffered", True)
# FOUND_ROWS must be set in ClientFlag to enable
# supports_sane_rowcount.
if self.dbapi is not None:
try:
from mysql.connector.constants import ClientFlag
client_flags = opts.get(
"client_flags", ClientFlag.get_default()
)
client_flags |= ClientFlag.FOUND_ROWS
opts["client_flags"] = client_flags
except Exception:
pass
return [[], opts]
@util.memoized_property
def _mysqlconnector_version_info(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
if m:
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
@util.memoized_property
def _mysqlconnector_double_percents(self):
return not util.py3k and self._mysqlconnector_version_info < (2, 0)
def _detect_charset(self, connection):
return connection.connection.charset
def _extract_error_code(self, exception):
return exception.errno
def is_disconnect(self, e, connection, cursor):
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
if isinstance(e, exceptions):
return (
e.errno in errnos
or "MySQL Connection not available." in str(e)
or "Connection to MySQL is not available" in str(e)
)
else:
return False
def _compat_fetchall(self, rp, charset=None):
return rp.fetchall()
def _compat_fetchone(self, rp, charset=None):
return rp.fetchone()
_isolation_lookup = set(
[
"SERIALIZABLE",
"READ UNCOMMITTED",
"READ COMMITTED",
"REPEATABLE READ",
"AUTOCOMMIT",
]
)
def _set_isolation_level(self, connection, level):
if level == "AUTOCOMMIT":
connection.autocommit = True
else:
connection.autocommit = False
super(MySQLDialect_mysqlconnector, self)._set_isolation_level(
connection, level
)
dialect = MySQLDialect_mysqlconnector
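if __name__ == "__main__":
    # Minimal connection sketch. Assumptions: mysql-connector-python is
    # installed and a server is reachable at this hypothetical URL; use_pure
    # is one of the query-string options coerced by create_connect_args above.
    from sqlalchemy import create_engine, text

    engine = create_engine(
        "mysql+mysqlconnector://scott:tiger@localhost/test?use_pure=true"
    )
    with engine.connect() as conn:
        print(conn.execute(text("SELECT 1")).scalar())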

View File

@@ -0,0 +1,331 @@
# mysql/mysqldb.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+mysqldb
:name: mysqlclient (maintained fork of MySQL-Python)
:dbapi: mysqldb
:connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
:url: https://pypi.org/project/mysqlclient/
Driver Status
-------------
The mysqlclient DBAPI is a maintained fork of the
`MySQL-Python <https://sourceforge.net/projects/mysql-python>`_ DBAPI,
which is itself no longer maintained. `mysqlclient`_ supports Python 2
and Python 3 and is very stable.
.. _mysqlclient: https://github.com/PyMySQL/mysqlclient-python
.. _mysqldb_unicode:
Unicode
-------
Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.
.. _mysqldb_ssl:
SSL Connections
----------------
The mysqlclient and PyMySQL DBAPIs accept an additional dictionary under the
key "ssl", which may be specified using the
:paramref:`_sa.create_engine.connect_args` dictionary::
engine = create_engine(
"mysql+mysqldb://scott:tiger@192.168.0.134/test",
connect_args={
"ssl": {
"ssl_ca": "/home/gord/client-ssl/ca.pem",
"ssl_cert": "/home/gord/client-ssl/client-cert.pem",
"ssl_key": "/home/gord/client-ssl/client-key.pem"
}
}
)
For convenience, the following keys may also be specified inline within the URL
where they will be interpreted into the "ssl" dictionary automatically:
"ssl_ca", "ssl_cert", "ssl_key", "ssl_capath", "ssl_cipher",
"ssl_check_hostname". An example is as follows::
connection_uri = (
"mysql+mysqldb://scott:tiger@192.168.0.134/test"
"?ssl_ca=/home/gord/client-ssl/ca.pem"
"&ssl_cert=/home/gord/client-ssl/client-cert.pem"
"&ssl_key=/home/gord/client-ssl/client-key.pem"
)
If the server uses an automatically-generated certificate that is self-signed
or does not match the host name (as seen from the client), it may also be
necessary to indicate ``ssl_check_hostname=false``::
connection_uri = (
"mysql+pymysql://scott:tiger@192.168.0.134/test"
"?ssl_ca=/home/gord/client-ssl/ca.pem"
"&ssl_cert=/home/gord/client-ssl/client-cert.pem"
"&ssl_key=/home/gord/client-ssl/client-key.pem"
"&ssl_check_hostname=false"
)
.. seealso::
:ref:`pymysql_ssl` in the PyMySQL dialect
Using MySQLdb with Google Cloud SQL
-----------------------------------
Google Cloud SQL now recommends use of the MySQLdb dialect. Connect
using a URL like the following::
mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>
Server Side Cursors
-------------------
The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
"""
import re
from .base import MySQLCompiler
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .base import MySQLIdentifierPreparer
from .base import TEXT
from ... import sql
from ... import util
class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
@property
def rowcount(self):
if hasattr(self, "_rowcount"):
return self._rowcount
else:
return self.cursor.rowcount
class MySQLCompiler_mysqldb(MySQLCompiler):
pass
class MySQLDialect_mysqldb(MySQLDialect):
driver = "mysqldb"
supports_statement_cache = True
supports_unicode_statements = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
supports_native_decimal = True
default_paramstyle = "format"
execution_ctx_cls = MySQLExecutionContext_mysqldb
statement_compiler = MySQLCompiler_mysqldb
preparer = MySQLIdentifierPreparer
def __init__(self, **kwargs):
super(MySQLDialect_mysqldb, self).__init__(**kwargs)
self._mysql_dbapi_version = (
self._parse_dbapi_version(self.dbapi.__version__)
if self.dbapi is not None and hasattr(self.dbapi, "__version__")
else (0, 0, 0)
)
def _parse_dbapi_version(self, version):
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
if m:
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
else:
return (0, 0, 0)
@util.langhelpers.memoized_property
def supports_server_side_cursors(self):
try:
cursors = __import__("MySQLdb.cursors").cursors
self._sscursor = cursors.SSCursor
return True
except (ImportError, AttributeError):
return False
@classmethod
def dbapi(cls):
return __import__("MySQLdb")
def on_connect(self):
super_ = super(MySQLDialect_mysqldb, self).on_connect()
def on_connect(conn):
if super_ is not None:
super_(conn)
charset_name = conn.character_set_name()
if charset_name is not None:
cursor = conn.cursor()
cursor.execute("SET NAMES %s" % charset_name)
cursor.close()
return on_connect
def do_ping(self, dbapi_connection):
try:
dbapi_connection.ping(False)
except self.dbapi.Error as err:
if self.is_disconnect(err, dbapi_connection, None):
return False
else:
raise
else:
return True
def do_executemany(self, cursor, statement, parameters, context=None):
rowcount = cursor.executemany(statement, parameters)
if context is not None:
context._rowcount = rowcount
def _check_unicode_returns(self, connection):
# work around issue fixed in
# https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
# specific issue w/ the utf8mb4_bin collation and unicode returns
collation = connection.exec_driver_sql(
"show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
% (
self.identifier_preparer.quote("Charset"),
self.identifier_preparer.quote("Collation"),
)
).scalar()
has_utf8mb4_bin = self.server_version_info > (5,) and collation
if has_utf8mb4_bin:
additional_tests = [
sql.collate(
sql.cast(
sql.literal_column("'test collated returns'"),
TEXT(charset="utf8mb4"),
),
"utf8mb4_bin",
)
]
else:
additional_tests = []
return super(MySQLDialect_mysqldb, self)._check_unicode_returns(
connection, additional_tests
)
def create_connect_args(self, url, _translate_args=None):
if _translate_args is None:
_translate_args = dict(
database="db", username="user", password="passwd"
)
opts = url.translate_connect_args(**_translate_args)
opts.update(url.query)
util.coerce_kw_type(opts, "compress", bool)
util.coerce_kw_type(opts, "connect_timeout", int)
util.coerce_kw_type(opts, "read_timeout", int)
util.coerce_kw_type(opts, "write_timeout", int)
util.coerce_kw_type(opts, "client_flag", int)
util.coerce_kw_type(opts, "local_infile", int)
# Note: using either of the below will cause all strings to be
# returned as Unicode, both in raw SQL operations and with column
# types like String and MSString.
util.coerce_kw_type(opts, "use_unicode", bool)
util.coerce_kw_type(opts, "charset", str)
# Rich values 'cursorclass' and 'conv' are not supported via
# query string.
ssl = {}
keys = [
("ssl_ca", str),
("ssl_key", str),
("ssl_cert", str),
("ssl_capath", str),
("ssl_cipher", str),
("ssl_check_hostname", bool),
]
for key, kw_type in keys:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], kw_type)
del opts[key]
if ssl:
opts["ssl"] = ssl
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
client_flag = opts.get("client_flag", 0)
client_flag_found_rows = self._found_rows_client_flag()
if client_flag_found_rows is not None:
client_flag |= client_flag_found_rows
opts["client_flag"] = client_flag
return [[], opts]
def _found_rows_client_flag(self):
if self.dbapi is not None:
try:
CLIENT_FLAGS = __import__(
self.dbapi.__name__ + ".constants.CLIENT"
).constants.CLIENT
except (AttributeError, ImportError):
return None
else:
return CLIENT_FLAGS.FOUND_ROWS
else:
return None
def _extract_error_code(self, exception):
return exception.args[0]
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
try:
# note: the SQL here would be
# "SHOW VARIABLES LIKE 'character_set%%'"
cset_name = connection.connection.character_set_name
except AttributeError:
util.warn(
"No 'character_set_name' can be detected with "
"this MySQL-Python version; "
"please upgrade to a recent version of MySQL-Python. "
"Assuming latin1."
)
return "latin1"
else:
return cset_name()
_isolation_lookup = set(
[
"SERIALIZABLE",
"READ UNCOMMITTED",
"READ COMMITTED",
"REPEATABLE READ",
"AUTOCOMMIT",
]
)
def _set_isolation_level(self, connection, level):
if level == "AUTOCOMMIT":
connection.autocommit(True)
else:
connection.autocommit(False)
super(MySQLDialect_mysqldb, self)._set_isolation_level(
connection, level
)
dialect = MySQLDialect_mysqldb
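if __name__ == "__main__":
    # Sketch of the server-side cursor support noted in the module docstring.
    # Assumptions: mysqlclient is installed; "big_table" and the URL are
    # hypothetical.
    from sqlalchemy import create_engine, text

    engine = create_engine("mysql+mysqldb://scott:tiger@localhost/test")
    with engine.connect() as conn:
        result = conn.execution_options(stream_results=True).execute(
            text("SELECT id FROM big_table")
        )
        for row in result:
            pass  # rows arrive incrementally through MySQLdb's SSCursor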

View File

@@ -0,0 +1,273 @@
# mysql/oursql.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: mysql+oursql
:name: OurSQL
:dbapi: oursql
:connectstring: mysql+oursql://<user>:<password>@<host>[:<port>]/<dbname>
:url: https://packages.python.org/oursql/
.. note::
The OurSQL MySQL dialect is legacy and is no longer supported upstream,
and is **not tested as part of SQLAlchemy's continuous integration**.
The recommended MySQL dialects are mysqlclient and PyMySQL.
.. deprecated:: 1.4 The OurSQL DBAPI is deprecated and will be removed
in a future version. Please use one of the supported DBAPIs to
connect to mysql.
Unicode
-------
Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.
"""
from .base import BIT
from .base import MySQLDialect
from .base import MySQLExecutionContext
from ... import types as sqltypes
from ... import util
class _oursqlBIT(BIT):
def result_processor(self, dialect, coltype):
"""oursql already converts mysql bits, so."""
return None
class MySQLExecutionContext_oursql(MySQLExecutionContext):
@property
def plain_query(self):
return self.execution_options.get("_oursql_plain_query", False)
class MySQLDialect_oursql(MySQLDialect):
driver = "oursql"
supports_statement_cache = True
if util.py2k:
supports_unicode_binds = True
supports_unicode_statements = True
supports_native_decimal = True
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
execution_ctx_cls = MySQLExecutionContext_oursql
colspecs = util.update_copy(
MySQLDialect.colspecs, {sqltypes.Time: sqltypes.Time, BIT: _oursqlBIT}
)
@classmethod
def dbapi(cls):
util.warn_deprecated(
"The OurSQL DBAPI is deprecated and will be removed "
"in a future version. Please use one of the supported DBAPIs to "
"connect to mysql.",
version="1.4",
)
return __import__("oursql")
def do_execute(self, cursor, statement, parameters, context=None):
"""Provide an implementation of
*cursor.execute(statement, parameters)*."""
if context and context.plain_query:
cursor.execute(statement, plain_query=True)
else:
cursor.execute(statement, parameters)
def do_begin(self, connection):
connection.cursor().execute("BEGIN", plain_query=True)
def _xa_query(self, connection, query, xid):
if util.py2k:
arg = connection.connection._escape_string(xid)
else:
charset = self._connection_charset
arg = connection.connection._escape_string(
xid.encode(charset)
).decode(charset)
arg = "'%s'" % arg
connection.execution_options(_oursql_plain_query=True).exec_driver_sql(
query % arg
)
# Because mysql is bad, these methods have to be
# reimplemented to use _PlainQuery. Basically, some queries
# refuse to return any data if they're run through
# the parameterized query API, or refuse to be parameterized
# in the first place.
def do_begin_twophase(self, connection, xid):
self._xa_query(connection, "XA BEGIN %s", xid)
def do_prepare_twophase(self, connection, xid):
self._xa_query(connection, "XA END %s", xid)
self._xa_query(connection, "XA PREPARE %s", xid)
def do_rollback_twophase(
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
self._xa_query(connection, "XA END %s", xid)
self._xa_query(connection, "XA ROLLBACK %s", xid)
def do_commit_twophase(
self, connection, xid, is_prepared=True, recover=False
):
if not is_prepared:
self.do_prepare_twophase(connection, xid)
self._xa_query(connection, "XA COMMIT %s", xid)
# Q: why didn't we need all these "plain_query" overrides earlier?
# am I on a newer/older version of OurSQL?
def has_table(self, connection, table_name, schema=None):
return MySQLDialect.has_table(
self,
connection.connect().execution_options(_oursql_plain_query=True),
table_name,
schema,
)
def get_table_options(self, connection, table_name, schema=None, **kw):
return MySQLDialect.get_table_options(
self,
connection.connect().execution_options(_oursql_plain_query=True),
table_name,
schema=schema,
**kw
)
def get_columns(self, connection, table_name, schema=None, **kw):
return MySQLDialect.get_columns(
self,
connection.connect().execution_options(_oursql_plain_query=True),
table_name,
schema=schema,
**kw
)
def get_view_names(self, connection, schema=None, **kw):
return MySQLDialect.get_view_names(
self,
connection.connect().execution_options(_oursql_plain_query=True),
schema=schema,
**kw
)
def get_table_names(self, connection, schema=None, **kw):
return MySQLDialect.get_table_names(
self,
connection.connect().execution_options(_oursql_plain_query=True),
schema,
)
def get_schema_names(self, connection, **kw):
return MySQLDialect.get_schema_names(
self,
connection.connect().execution_options(_oursql_plain_query=True),
**kw
)
def initialize(self, connection):
return MySQLDialect.initialize(
self, connection.execution_options(_oursql_plain_query=True)
)
def _show_create_table(
self, connection, table, charset=None, full_name=None
):
return MySQLDialect._show_create_table(
self,
connection.connect(close_with_result=True).execution_options(
_oursql_plain_query=True
),
table,
charset,
full_name,
)
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.ProgrammingError):
return (
e.errno is None
and "cursor" not in e.args[1]
and e.args[1].endswith("closed")
)
else:
return e.errno in (2006, 2013, 2014, 2045, 2055)
def create_connect_args(self, url):
opts = url.translate_connect_args(
database="db", username="user", password="passwd"
)
opts.update(url.query)
util.coerce_kw_type(opts, "port", int)
util.coerce_kw_type(opts, "compress", bool)
util.coerce_kw_type(opts, "autoping", bool)
util.coerce_kw_type(opts, "raise_on_warnings", bool)
util.coerce_kw_type(opts, "default_charset", bool)
if opts.pop("default_charset", False):
opts["charset"] = None
else:
util.coerce_kw_type(opts, "charset", str)
opts["use_unicode"] = opts.get("use_unicode", True)
util.coerce_kw_type(opts, "use_unicode", bool)
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
# supports_sane_rowcount.
opts.setdefault("found_rows", True)
ssl = {}
for key in [
"ssl_ca",
"ssl_key",
"ssl_cert",
"ssl_capath",
"ssl_cipher",
]:
if key in opts:
ssl[key[4:]] = opts[key]
util.coerce_kw_type(ssl, key[4:], str)
del opts[key]
if ssl:
opts["ssl"] = ssl
return [[], opts]
def _extract_error_code(self, exception):
return exception.errno
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
return connection.connection.charset
def _compat_fetchall(self, rp, charset=None):
"""oursql isn't super-broken like MySQLdb, yaaay."""
return rp.fetchall()
def _compat_fetchone(self, rp, charset=None):
"""oursql isn't super-broken like MySQLdb, yaaay."""
return rp.fetchone()
def _compat_first(self, rp, charset=None):
return rp.first()
dialect = MySQLDialect_oursql

View File

@@ -0,0 +1,78 @@
from ... import exc
from ...testing.provision import configure_follower
from ...testing.provision import create_db
from ...testing.provision import drop_db
from ...testing.provision import generate_driver_url
from ...testing.provision import temp_table_keyword_args
@generate_driver_url.for_db("mysql", "mariadb")
def generate_driver_url(url, driver, query_str):
backend = url.get_backend_name()
# NOTE: at the moment, tests are running mariadbconnector
# against both mariadb and mysql backends. if we want this to be
# limited, do the decision making here to reject a "mysql+mariadbconnector"
# URL. Optionally also re-enable the module level
# MySQLDialect_mariadbconnector.is_mysql flag as well, which must include
# a unit and/or functional test.
# all the Jenkins tests have been running the mysqlclient Python library
# built against mariadb client drivers for years against all MySQL /
# MariaDB versions going back to MySQL 5.6, currently they can talk
# to MySQL databases without problems.
if backend == "mysql":
dialect_cls = url.get_dialect()
if dialect_cls._is_mariadb_from_url(url):
backend = "mariadb"
new_url = url.set(
drivername="%s+%s" % (backend, driver)
).update_query_string(query_str)
try:
new_url.get_dialect()
except exc.NoSuchModuleError:
return None
else:
return new_url
@create_db.for_db("mysql", "mariadb")
def _mysql_create_db(cfg, eng, ident):
with eng.begin() as conn:
try:
_mysql_drop_db(cfg, conn, ident)
except Exception:
pass
with eng.begin() as conn:
conn.exec_driver_sql(
"CREATE DATABASE %s CHARACTER SET utf8mb4" % ident
)
conn.exec_driver_sql(
"CREATE DATABASE %s_test_schema CHARACTER SET utf8mb4" % ident
)
conn.exec_driver_sql(
"CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb4" % ident
)
@configure_follower.for_db("mysql", "mariadb")
def _mysql_configure_follower(config, ident):
config.test_schema = "%s_test_schema" % ident
config.test_schema_2 = "%s_test_schema_2" % ident
@drop_db.for_db("mysql", "mariadb")
def _mysql_drop_db(cfg, eng, ident):
with eng.begin() as conn:
conn.exec_driver_sql("DROP DATABASE %s_test_schema" % ident)
conn.exec_driver_sql("DROP DATABASE %s_test_schema_2" % ident)
conn.exec_driver_sql("DROP DATABASE %s" % ident)
@temp_table_keyword_args.for_db("mysql", "mariadb")
def _mysql_temp_table_keyword_args(cfg, eng):
return {"prefixes": ["TEMPORARY"]}

View File

@@ -0,0 +1,98 @@
# mysql/pymysql.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: mysql+pymysql
:name: PyMySQL
:dbapi: pymysql
:connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>[?<options>]
:url: https://pymysql.readthedocs.io/
Unicode
-------
Please see :ref:`mysql_unicode` for current recommendations on unicode
handling.
.. _pymysql_ssl:
SSL Connections
------------------
The PyMySQL DBAPI accepts the same SSL arguments as that of MySQLdb,
described at :ref:`mysqldb_ssl`. See that section for examples.
MySQL-Python Compatibility
--------------------------
The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver,
and targets 100% compatibility. Most behavioral notes for MySQL-python apply
to the pymysql driver as well.
""" # noqa
from .mysqldb import MySQLDialect_mysqldb
from ...util import langhelpers
from ...util import py3k
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
driver = "pymysql"
supports_statement_cache = True
description_encoding = None
# generally, these two values should be both True
# or both False. PyMySQL unicode tests pass all the way back
# to 0.4 either way. See [ticket:3337]
supports_unicode_statements = True
supports_unicode_binds = True
@langhelpers.memoized_property
def supports_server_side_cursors(self):
try:
cursors = __import__("pymysql.cursors").cursors
self._sscursor = cursors.SSCursor
return True
except (ImportError, AttributeError):
return False
@classmethod
def dbapi(cls):
return __import__("pymysql")
def create_connect_args(self, url, _translate_args=None):
if _translate_args is None:
_translate_args = dict(username="user")
return super(MySQLDialect_pymysql, self).create_connect_args(
url, _translate_args=_translate_args
)
def is_disconnect(self, e, connection, cursor):
if super(MySQLDialect_pymysql, self).is_disconnect(
e, connection, cursor
):
return True
elif isinstance(e, self.dbapi.Error):
str_e = str(e).lower()
return (
"already closed" in str_e or "connection was killed" in str_e
)
else:
return False
if py3k:
def _extract_error_code(self, exception):
if isinstance(exception.args[0], Exception):
exception = exception.args[0]
return exception.args[0]
dialect = MySQLDialect_pymysql
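if __name__ == "__main__":
    # Minimal sketch. Assumptions: PyMySQL is installed; the URL is
    # hypothetical. The ssl_* query parameters documented for mysqldb
    # apply to this dialect as well.
    from sqlalchemy import create_engine, text

    engine = create_engine(
        "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4"
    )
    with engine.connect() as conn:
        print(conn.execute(text("SELECT @@version")).scalar())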

View File

@@ -0,0 +1,136 @@
# mysql/pyodbc.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: mysql+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: mysql+pyodbc://<username>:<password>@<dsnname>
:url: https://pypi.org/project/pyodbc/
.. note::
The PyODBC for MySQL dialect is **not tested as part of
SQLAlchemy's continuous integration**.
The recommended MySQL dialects are mysqlclient and PyMySQL.
However, if you want to use the mysql+pyodbc dialect and require
full support for ``utf8mb4`` characters (including supplementary
characters like emoji) be sure to use a current release of
MySQL Connector/ODBC and specify the "ANSI" (**not** "Unicode")
version of the driver in your DSN or connection string.
Pass through exact pyodbc connection string::
import urllib
connection_string = (
'DRIVER=MySQL ODBC 8.0 ANSI Driver;'
'SERVER=localhost;'
'PORT=3307;'
'DATABASE=mydb;'
'UID=root;'
'PWD=(whatever);'
'charset=utf8mb4;'
)
params = urllib.parse.quote_plus(connection_string)
connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params
""" # noqa
import re
from .base import MySQLDialect
from .base import MySQLExecutionContext
from .types import TIME
from ... import exc
from ... import util
from ...connectors.pyodbc import PyODBCConnector
from ...sql.sqltypes import Time
class _pyodbcTIME(TIME):
def result_processor(self, dialect, coltype):
def process(value):
# pyodbc returns a datetime.time object; no need to convert
return value
return process
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
def get_lastrowid(self):
cursor = self.create_cursor()
cursor.execute("SELECT LAST_INSERT_ID()")
lastrowid = cursor.fetchone()[0]
cursor.close()
return lastrowid
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
supports_statement_cache = True
colspecs = util.update_copy(MySQLDialect.colspecs, {Time: _pyodbcTIME})
supports_unicode_statements = True
execution_ctx_cls = MySQLExecutionContext_pyodbc
pyodbc_driver_name = "MySQL"
def _detect_charset(self, connection):
"""Sniff out the character set in use for connection results."""
# Prefer 'character_set_results' for the current connection over the
# value in the driver. SET NAMES or individual variable SETs will
# change the charset without updating the driver's view of the world.
#
# If it's decided that issuing that sort of SQL leaves you SOL, then
# this can prefer the driver value.
# set this to None as _fetch_setting attempts to use it (None is OK)
self._connection_charset = None
try:
value = self._fetch_setting(connection, "character_set_client")
if value:
return value
except exc.DBAPIError:
pass
util.warn(
"Could not detect the connection character set. "
"Assuming latin1."
)
return "latin1"
def _get_server_version_info(self, connection):
return MySQLDialect._get_server_version_info(self, connection)
def _extract_error_code(self, exception):
m = re.compile(r"\((\d+)\)").search(str(exception.args))
# the pattern may not match at all; guard against a None match
if m:
return int(m.group(1))
else:
return None
def on_connect(self):
super_ = super(MySQLDialect_pyodbc, self).on_connect()
def on_connect(conn):
if super_ is not None:
super_(conn)
# declare Unicode encoding for pyodbc as per
# https://github.com/mkleehammer/pyodbc/wiki/Unicode
pyodbc_SQL_CHAR = 1 # pyodbc.SQL_CHAR
pyodbc_SQL_WCHAR = -8 # pyodbc.SQL_WCHAR
conn.setdecoding(pyodbc_SQL_CHAR, encoding="utf-8")
conn.setdecoding(pyodbc_SQL_WCHAR, encoding="utf-8")
conn.setencoding(encoding="utf-8")
return on_connect
dialect = MySQLDialect_pyodbc

View File

@@ -0,0 +1,558 @@
# mysql/reflection.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
import re
from .enumerated import ENUM
from .enumerated import SET
from .types import DATETIME
from .types import TIME
from .types import TIMESTAMP
from ... import log
from ... import types as sqltypes
from ... import util
class ReflectedState(object):
"""Stores raw information about a SHOW CREATE TABLE statement."""
def __init__(self):
self.columns = []
self.table_options = {}
self.table_name = None
self.keys = []
self.fk_constraints = []
self.ck_constraints = []
@log.class_logger
class MySQLTableDefinitionParser(object):
"""Parses the results of a SHOW CREATE TABLE statement."""
def __init__(self, dialect, preparer):
self.dialect = dialect
self.preparer = preparer
self._prep_regexes()
def parse(self, show_create, charset):
state = ReflectedState()
state.charset = charset
for line in re.split(r"\r?\n", show_create):
if line.startswith(" " + self.preparer.initial_quote):
self._parse_column(line, state)
# a regular table options line
elif line.startswith(") "):
self._parse_table_options(line, state)
# an ANSI-mode table options line
elif line == ")":
pass
elif line.startswith("CREATE "):
self._parse_table_name(line, state)
# Not present in real reflection, but may be if
# loading from a file.
elif not line:
pass
else:
type_, spec = self._parse_constraints(line)
if type_ is None:
util.warn("Unknown schema content: %r" % line)
elif type_ == "key":
state.keys.append(spec)
elif type_ == "fk_constraint":
state.fk_constraints.append(spec)
elif type_ == "ck_constraint":
state.ck_constraints.append(spec)
else:
pass
return state
def _parse_constraints(self, line):
"""Parse a KEY or CONSTRAINT line.
:param line: A line of SHOW CREATE TABLE output
"""
# KEY
m = self._re_key.match(line)
if m:
spec = m.groupdict()
# convert columns into name, length pairs
# NOTE: we may want to consider SHOW INDEX as the
# format of indexes in MySQL becomes more complex
spec["columns"] = self._parse_keyexprs(spec["columns"])
if spec["version_sql"]:
m2 = self._re_key_version_sql.match(spec["version_sql"])
if m2 and m2.groupdict()["parser"]:
spec["parser"] = m2.groupdict()["parser"]
if spec["parser"]:
spec["parser"] = self.preparer.unformat_identifiers(
spec["parser"]
)[0]
return "key", spec
# FOREIGN KEY CONSTRAINT
m = self._re_fk_constraint.match(line)
if m:
spec = m.groupdict()
spec["table"] = self.preparer.unformat_identifiers(spec["table"])
spec["local"] = [c[0] for c in self._parse_keyexprs(spec["local"])]
spec["foreign"] = [
c[0] for c in self._parse_keyexprs(spec["foreign"])
]
return "fk_constraint", spec
# CHECK constraint
m = self._re_ck_constraint.match(line)
if m:
spec = m.groupdict()
return "ck_constraint", spec
# PARTITION and SUBPARTITION
m = self._re_partition.match(line)
if m:
# Punt!
return "partition", line
# No match.
return (None, line)
def _parse_table_name(self, line, state):
"""Extract the table name.
:param line: The first line of SHOW CREATE TABLE
"""
regex, cleanup = self._pr_name
m = regex.match(line)
if m:
state.table_name = cleanup(m.group("name"))
def _parse_table_options(self, line, state):
"""Build a dictionary of all reflected table-level options.
:param line: The final line of SHOW CREATE TABLE output.
"""
options = {}
if not line or line == ")":
pass
else:
rest_of_line = line[:]
for regex, cleanup in self._pr_options:
m = regex.search(rest_of_line)
if not m:
continue
directive, value = m.group("directive"), m.group("val")
if cleanup:
value = cleanup(value)
options[directive.lower()] = value
rest_of_line = regex.sub("", rest_of_line)
for nope in ("auto_increment", "data directory", "index directory"):
options.pop(nope, None)
for opt, val in options.items():
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
def _parse_column(self, line, state):
"""Extract column details.
Falls back to a 'minimal support' variant if full parse fails.
:param line: Any column-bearing line from SHOW CREATE TABLE
"""
spec = None
m = self._re_column.match(line)
if m:
spec = m.groupdict()
spec["full"] = True
else:
m = self._re_column_loose.match(line)
if m:
spec = m.groupdict()
spec["full"] = False
if not spec:
util.warn("Unknown column definition %r" % line)
return
if not spec["full"]:
util.warn("Incomplete reflection of column definition %r" % line)
name, type_, args = spec["name"], spec["coltype"], spec["arg"]
try:
col_type = self.dialect.ischema_names[type_]
except KeyError:
util.warn(
"Did not recognize type '%s' of column '%s'" % (type_, name)
)
col_type = sqltypes.NullType
# Column type positional arguments eg. varchar(32)
if args is None or args == "":
type_args = []
elif args[0] == "'" and args[-1] == "'":
type_args = self._re_csv_str.findall(args)
else:
type_args = [int(v) for v in self._re_csv_int.findall(args)]
# Column type keyword options
type_kw = {}
if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)):
if type_args:
type_kw["fsp"] = type_args.pop(0)
for kw in ("unsigned", "zerofill"):
if spec.get(kw, False):
type_kw[kw] = True
for kw in ("charset", "collate"):
if spec.get(kw, False):
type_kw[kw] = spec[kw]
if issubclass(col_type, (ENUM, SET)):
type_args = _strip_values(type_args)
if issubclass(col_type, SET) and "" in type_args:
type_kw["retrieve_as_bitwise"] = True
type_instance = col_type(*type_args, **type_kw)
col_kw = {}
# NOT NULL
col_kw["nullable"] = True
# this can be "NULL" in the case of TIMESTAMP
if spec.get("notnull", False) == "NOT NULL":
col_kw["nullable"] = False
# AUTO_INCREMENT
if spec.get("autoincr", False):
col_kw["autoincrement"] = True
elif issubclass(col_type, sqltypes.Integer):
col_kw["autoincrement"] = False
# DEFAULT
default = spec.get("default", None)
if default == "NULL":
# eliminates the need to deal with this later.
default = None
comment = spec.get("comment", None)
if comment is not None:
comment = comment.replace("\\\\", "\\").replace("''", "'")
sqltext = spec.get("generated")
if sqltext is not None:
computed = dict(sqltext=sqltext)
persisted = spec.get("persistence")
if persisted is not None:
computed["persisted"] = persisted == "STORED"
col_kw["computed"] = computed
col_d = dict(
name=name, type=type_instance, default=default, comment=comment
)
col_d.update(col_kw)
state.columns.append(col_d)
def _describe_to_create(self, table_name, columns):
"""Re-format DESCRIBE output as a SHOW CREATE TABLE string.
DESCRIBE is a much simpler reflection and is sufficient for
reflecting views for runtime use. This method formats DDL
for columns only; keys are omitted.
:param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
SHOW FULL COLUMNS FROM rows must be rearranged for use with
this function.
"""
buffer = []
for row in columns:
(name, col_type, nullable, default, extra) = [
row[i] for i in (0, 1, 2, 4, 5)
]
line = [" "]
line.append(self.preparer.quote_identifier(name))
line.append(col_type)
if not nullable:
line.append("NOT NULL")
if default:
if "auto_increment" in default:
pass
elif col_type.startswith("timestamp") and default.startswith(
"C"
):
line.append("DEFAULT")
line.append(default)
elif default == "NULL":
line.append("DEFAULT")
line.append(default)
else:
line.append("DEFAULT")
line.append("'%s'" % default.replace("'", "''"))
if extra:
line.append(extra)
buffer.append(" ".join(line))
return "".join(
[
(
"CREATE TABLE %s (\n"
% self.preparer.quote_identifier(table_name)
),
",\n".join(buffer),
"\n) ",
]
)
def _parse_keyexprs(self, identifiers):
"""Unpack '"col"(2),"col" ASC'-ish strings into components."""
return self._re_keyexprs.findall(identifiers)
def _prep_regexes(self):
"""Pre-compile regular expressions."""
self._re_columns = []
self._pr_options = []
_final = self.preparer.final_quote
quotes = dict(
zip(
("iq", "fq", "esc_fq"),
[
re.escape(s)
for s in (
self.preparer.initial_quote,
_final,
self.preparer._escape_identifier(_final),
)
],
)
)
self._pr_name = _pr_compile(
r"^CREATE (?:\w+ +)?TABLE +"
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($" % quotes,
self.preparer._unescape_identifier,
)
# `col`,`col2`(32),`col3`(15) DESC
#
self._re_keyexprs = _re_compile(
r"(?:"
r"(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)"
r"(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+" % quotes
)
# 'foo' or 'foo','bar' or 'fo,o','ba''a''r'
self._re_csv_str = _re_compile(r"\x27(?:\x27\x27|[^\x27])*\x27")
# 123 or 123,456
self._re_csv_int = _re_compile(r"\d+")
# `colname` <type> [type opts]
# (NOT NULL | NULL)
# DEFAULT ('value' | CURRENT_TIMESTAMP...)
# COMMENT 'comment'
# COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT)
# STORAGE (DISK|MEMORY)
self._re_column = _re_compile(
r" "
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
r"(?P<coltype>\w+)"
r"(?:\((?P<arg>(?:\d+|\d+,\d+|"
r"(?:'(?:''|[^'])*',?)+))\))?"
r"(?: +(?P<unsigned>UNSIGNED))?"
r"(?: +(?P<zerofill>ZEROFILL))?"
r"(?: +CHARACTER SET +(?P<charset>[\w_]+))?"
r"(?: +COLLATE +(?P<collate>[\w_]+))?"
r"(?: +(?P<notnull>(?:NOT )?NULL))?"
r"(?: +DEFAULT +(?P<default>"
r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+"
r"(?: +ON UPDATE [\-\w\.\(\)]+)?)"
r"))?"
r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\("
r".*\))? ?(?P<persistence>VIRTUAL|STORED)?)?"
r"(?: +(?P<autoincr>AUTO_INCREMENT))?"
r"(?: +COMMENT +'(?P<comment>(?:''|[^'])*)')?"
r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?"
r"(?: +STORAGE +(?P<storage>\w+))?"
r"(?: +(?P<extra>.*))?"
r",?$" % quotes
)
# Fallback, try to parse as little as possible
self._re_column_loose = _re_compile(
r" "
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
r"(?P<coltype>\w+)"
r"(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?"
r".*?(?P<notnull>(?:NOT )NULL)?" % quotes
)
# (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))?
# (`col` (ASC|DESC)?, `col` (ASC|DESC)?)
# KEY_BLOCK_SIZE size | WITH PARSER name /*!50100 WITH PARSER name */
self._re_key = _re_compile(
r" "
r"(?:(?P<type>\S+) )?KEY"
r"(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?"
r"(?: +USING +(?P<using_pre>\S+))?"
r" +\((?P<columns>.+?)\)"
r"(?: +USING +(?P<using_post>\S+))?"
r"(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?"
r"(?: +WITH PARSER +(?P<parser>\S+))?"
r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?"
r"(?: +/\*(?P<version_sql>.+)\*/ *)?"
r",?$" % quotes
)
# https://forums.mysql.com/read.php?20,567102,567111#msg-567111
# It means if the MySQL version >= \d+, execute what's in the comment
self._re_key_version_sql = _re_compile(
r"\!\d+ " r"(?: *WITH PARSER +(?P<parser>\S+) *)?"
)
# CONSTRAINT `name` FOREIGN KEY (`local_col`)
# REFERENCES `remote` (`remote_col`)
# MATCH FULL | MATCH PARTIAL | MATCH SIMPLE
# ON DELETE CASCADE ON UPDATE RESTRICT
#
# unique constraints come back as KEYs
kw = quotes.copy()
kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
self._re_fk_constraint = _re_compile(
r" "
r"CONSTRAINT +"
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
r"FOREIGN KEY +"
r"\((?P<local>[^\)]+?)\) REFERENCES +"
r"(?P<table>%(iq)s[^%(fq)s]+%(fq)s"
r"(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +"
r"\((?P<foreign>[^\)]+?)\)"
r"(?: +(?P<match>MATCH \w+))?"
r"(?: +ON DELETE (?P<ondelete>%(on)s))?"
r"(?: +ON UPDATE (?P<onupdate>%(on)s))?" % kw
)
# CONSTRAINT `CONSTRAINT_1` CHECK (`x` > 5)'
# testing on MariaDB 10.2 shows that the CHECK constraint
# is returned on a line by itself, so to match without worrying
# about parenthesis in the expression we go to the end of the line
self._re_ck_constraint = _re_compile(
r" "
r"CONSTRAINT +"
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
r"CHECK +"
r"\((?P<sqltext>.+)\),?" % kw
)
# PARTITION
#
# punt!
self._re_partition = _re_compile(r"(?:.*)(?:SUB)?PARTITION(?:.*)")
# Table-level options (COLLATE, ENGINE, etc.)
# Do the string options first, since they have quoted
# strings we need to get rid of.
for option in _options_of_type_string:
self._add_option_string(option)
for option in (
"ENGINE",
"TYPE",
"AUTO_INCREMENT",
"AVG_ROW_LENGTH",
"CHARACTER SET",
"DEFAULT CHARSET",
"CHECKSUM",
"COLLATE",
"DELAY_KEY_WRITE",
"INSERT_METHOD",
"MAX_ROWS",
"MIN_ROWS",
"PACK_KEYS",
"ROW_FORMAT",
"KEY_BLOCK_SIZE",
):
self._add_option_word(option)
self._add_option_regex("UNION", r"\([^\)]+\)")
self._add_option_regex("TABLESPACE", r".*? STORAGE DISK")
self._add_option_regex(
"RAID_TYPE",
r"\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+",
)
_optional_equals = r"(?:\s*(?:=\s*)|\s+)"
def _add_option_string(self, directive):
regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % (
re.escape(directive),
self._optional_equals,
)
self._pr_options.append(
_pr_compile(
regex, lambda v: v.replace("\\\\", "\\").replace("''", "'")
)
)
def _add_option_word(self, directive):
regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % (
re.escape(directive),
self._optional_equals,
)
self._pr_options.append(_pr_compile(regex))
def _add_option_regex(self, directive, regex):
regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
re.escape(directive),
self._optional_equals,
regex,
)
self._pr_options.append(_pr_compile(regex))
_options_of_type_string = (
"COMMENT",
"DATA DIRECTORY",
"INDEX DIRECTORY",
"PASSWORD",
"CONNECTION",
)
def _pr_compile(regex, cleanup=None):
"""Prepare a 2-tuple of compiled regex and callable."""
return (_re_compile(regex), cleanup)
def _re_compile(regex):
"""Compile a string to regex, I and UNICODE."""
return re.compile(regex, re.I | re.UNICODE)
def _strip_values(values):
"Strip reflected values quotes"
strip_values = []
for a in values:
if a[0:1] == '"' or a[0:1] == "'":
# strip enclosing quotes and unquote interior
a = a[1:-1].replace(a[0] * 2, a[0])
strip_values.append(a)
return strip_values
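if __name__ == "__main__":
    # Database-free illustration of the quote-stripping helper used when
    # reflecting ENUM/SET member values.
    print(_strip_values(["'foo'", "'fo,o'", "'ba''a''r'", '"quoted"']))
    # -> ['foo', 'fo,o', "ba'a'r", 'quoted']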

View File

@@ -0,0 +1,564 @@
# mysql/reserved_words.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# generated using:
# https://gist.github.com/kkirsche/4f31f2153ed7a3248be1ec44ca6ddbc9
#
# https://mariadb.com/kb/en/reserved-words/
# includes: Reserved Words, Oracle Mode (separate set unioned)
# excludes: Exceptions, Function Names
RESERVED_WORDS_MARIADB = {
"accessible",
"add",
"all",
"alter",
"analyze",
"and",
"as",
"asc",
"asensitive",
"before",
"between",
"bigint",
"binary",
"blob",
"both",
"by",
"call",
"cascade",
"case",
"change",
"char",
"character",
"check",
"collate",
"column",
"condition",
"constraint",
"continue",
"convert",
"create",
"cross",
"current_date",
"current_role",
"current_time",
"current_timestamp",
"current_user",
"cursor",
"database",
"databases",
"day_hour",
"day_microsecond",
"day_minute",
"day_second",
"dec",
"decimal",
"declare",
"default",
"delayed",
"delete",
"desc",
"describe",
"deterministic",
"distinct",
"distinctrow",
"div",
"do_domain_ids",
"double",
"drop",
"dual",
"each",
"else",
"elseif",
"enclosed",
"escaped",
"except",
"exists",
"exit",
"explain",
"false",
"fetch",
"float",
"float4",
"float8",
"for",
"force",
"foreign",
"from",
"fulltext",
"general",
"grant",
"group",
"having",
"high_priority",
"hour_microsecond",
"hour_minute",
"hour_second",
"if",
"ignore",
"ignore_domain_ids",
"ignore_server_ids",
"in",
"index",
"infile",
"inner",
"inout",
"insensitive",
"insert",
"int",
"int1",
"int2",
"int3",
"int4",
"int8",
"integer",
"intersect",
"interval",
"into",
"is",
"iterate",
"join",
"key",
"keys",
"kill",
"leading",
"leave",
"left",
"like",
"limit",
"linear",
"lines",
"load",
"localtime",
"localtimestamp",
"lock",
"long",
"longblob",
"longtext",
"loop",
"low_priority",
"master_heartbeat_period",
"master_ssl_verify_server_cert",
"match",
"maxvalue",
"mediumblob",
"mediumint",
"mediumtext",
"middleint",
"minute_microsecond",
"minute_second",
"mod",
"modifies",
"natural",
"no_write_to_binlog",
"not",
"null",
"numeric",
"offset",
"on",
"optimize",
"option",
"optionally",
"or",
"order",
"out",
"outer",
"outfile",
"over",
"page_checksum",
"parse_vcol_expr",
"partition",
"position",
"precision",
"primary",
"procedure",
"purge",
"range",
"read",
"read_write",
"reads",
"real",
"recursive",
"ref_system_id",
"references",
"regexp",
"release",
"rename",
"repeat",
"replace",
"require",
"resignal",
"restrict",
"return",
"returning",
"revoke",
"right",
"rlike",
"rows",
"schema",
"schemas",
"second_microsecond",
"select",
"sensitive",
"separator",
"set",
"show",
"signal",
"slow",
"smallint",
"spatial",
"specific",
"sql",
"sql_big_result",
"sql_calc_found_rows",
"sql_small_result",
"sqlexception",
"sqlstate",
"sqlwarning",
"ssl",
"starting",
"stats_auto_recalc",
"stats_persistent",
"stats_sample_pages",
"straight_join",
"table",
"terminated",
"then",
"tinyblob",
"tinyint",
"tinytext",
"to",
"trailing",
"trigger",
"true",
"undo",
"union",
"unique",
"unlock",
"unsigned",
"update",
"usage",
"use",
"using",
"utc_date",
"utc_time",
"utc_timestamp",
"values",
"varbinary",
"varchar",
"varcharacter",
"varying",
"when",
"where",
"while",
"window",
"with",
"write",
"xor",
"year_month",
"zerofill",
}.union(
{
"body",
"elsif",
"goto",
"history",
"others",
"package",
"period",
"raise",
"rowtype",
"system",
"system_time",
"versioning",
"without",
}
)
# https://dev.mysql.com/doc/refman/8.0/en/keywords.html
# https://dev.mysql.com/doc/refman/5.7/en/keywords.html
# https://dev.mysql.com/doc/refman/5.6/en/keywords.html
# includes: MySQL x.0 Keywords and Reserved Words
# excludes: MySQL x.0 New Keywords and Reserved Words,
# MySQL x.0 Removed Keywords and Reserved Words
RESERVED_WORDS_MYSQL = {
"accessible",
"add",
"admin",
"all",
"alter",
"analyze",
"and",
"array",
"as",
"asc",
"asensitive",
"before",
"between",
"bigint",
"binary",
"blob",
"both",
"by",
"call",
"cascade",
"case",
"change",
"char",
"character",
"check",
"collate",
"column",
"condition",
"constraint",
"continue",
"convert",
"create",
"cross",
"cube",
"cume_dist",
"current_date",
"current_time",
"current_timestamp",
"current_user",
"cursor",
"database",
"databases",
"day_hour",
"day_microsecond",
"day_minute",
"day_second",
"dec",
"decimal",
"declare",
"default",
"delayed",
"delete",
"dense_rank",
"desc",
"describe",
"deterministic",
"distinct",
"distinctrow",
"div",
"double",
"drop",
"dual",
"each",
"else",
"elseif",
"empty",
"enclosed",
"escaped",
"except",
"exists",
"exit",
"explain",
"false",
"fetch",
"first_value",
"float",
"float4",
"float8",
"for",
"force",
"foreign",
"from",
"fulltext",
"function",
"general",
"generated",
"get",
"get_master_public_key",
"grant",
"group",
"grouping",
"groups",
"having",
"high_priority",
"hour_microsecond",
"hour_minute",
"hour_second",
"if",
"ignore",
"ignore_server_ids",
"in",
"index",
"infile",
"inner",
"inout",
"insensitive",
"insert",
"int",
"int1",
"int2",
"int3",
"int4",
"int8",
"integer",
"interval",
"into",
"io_after_gtids",
"io_before_gtids",
"is",
"iterate",
"join",
"json_table",
"key",
"keys",
"kill",
"lag",
"last_value",
"lateral",
"lead",
"leading",
"leave",
"left",
"like",
"limit",
"linear",
"lines",
"load",
"localtime",
"localtimestamp",
"lock",
"long",
"longblob",
"longtext",
"loop",
"low_priority",
"master_bind",
"master_heartbeat_period",
"master_ssl_verify_server_cert",
"match",
"maxvalue",
"mediumblob",
"mediumint",
"mediumtext",
"member",
"middleint",
"minute_microsecond",
"minute_second",
"mod",
"modifies",
"natural",
"no_write_to_binlog",
"not",
"nth_value",
"ntile",
"null",
"numeric",
"of",
"on",
"optimize",
"optimizer_costs",
"option",
"optionally",
"or",
"order",
"out",
"outer",
"outfile",
"over",
"parse_gcol_expr",
"partition",
"percent_rank",
"persist",
"persist_only",
"precision",
"primary",
"procedure",
"purge",
"range",
"rank",
"read",
"read_write",
"reads",
"real",
"recursive",
"references",
"regexp",
"release",
"rename",
"repeat",
"replace",
"require",
"resignal",
"restrict",
"return",
"revoke",
"right",
"rlike",
"role",
"row",
"row_number",
"rows",
"schema",
"schemas",
"second_microsecond",
"select",
"sensitive",
"separator",
"set",
"show",
"signal",
"slow",
"smallint",
"spatial",
"specific",
"sql",
"sql_after_gtids",
"sql_before_gtids",
"sql_big_result",
"sql_calc_found_rows",
"sql_small_result",
"sqlexception",
"sqlstate",
"sqlwarning",
"ssl",
"starting",
"stored",
"straight_join",
"system",
"table",
"terminated",
"then",
"tinyblob",
"tinyint",
"tinytext",
"to",
"trailing",
"trigger",
"true",
"undo",
"union",
"unique",
"unlock",
"unsigned",
"update",
"usage",
"use",
"using",
"utc_date",
"utc_time",
"utc_timestamp",
"values",
"varbinary",
"varchar",
"varcharacter",
"varying",
"virtual",
"when",
"where",
"while",
"window",
"with",
"write",
"xor",
"year_month",
"zerofill",
}

View File

@@ -0,0 +1,773 @@
# mysql/types.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
import datetime
from ... import exc
from ... import types as sqltypes
from ... import util
class _NumericType(object):
"""Base for MySQL numeric types.
This is the base both for NUMERIC as well as INTEGER, hence
it's a mixin.
"""
def __init__(self, unsigned=False, zerofill=False, **kw):
self.unsigned = unsigned
self.zerofill = zerofill
super(_NumericType, self).__init__(**kw)
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_NumericType, sqltypes.Numeric]
)
class _FloatType(_NumericType, sqltypes.Float):
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
if isinstance(self, (REAL, DOUBLE)) and (
(precision is None and scale is not None)
or (precision is not None and scale is None)
):
raise exc.ArgumentError(
"You must specify both precision and scale or omit "
"both altogether."
)
super(_FloatType, self).__init__(
precision=precision, asdecimal=asdecimal, **kw
)
self.scale = scale
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_FloatType, _NumericType, sqltypes.Float]
)
class _IntegerType(_NumericType, sqltypes.Integer):
def __init__(self, display_width=None, **kw):
self.display_width = display_width
super(_IntegerType, self).__init__(**kw)
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]
)
class _StringType(sqltypes.String):
"""Base for MySQL string types."""
def __init__(
self,
charset=None,
collation=None,
ascii=False, # noqa
binary=False,
unicode=False,
national=False,
**kw
):
self.charset = charset
# allow collate= or collation=
kw.setdefault("collation", kw.pop("collate", collation))
self.ascii = ascii
self.unicode = unicode
self.binary = binary
self.national = national
super(_StringType, self).__init__(**kw)
def __repr__(self):
return util.generic_repr(
self, to_inspect=[_StringType, sqltypes.String]
)
class _MatchType(sqltypes.Float, sqltypes.MatchType):
def __init__(self, **kw):
# TODO: float arguments?
sqltypes.Float.__init__(self)
sqltypes.MatchType.__init__(self)
class NUMERIC(_NumericType, sqltypes.NUMERIC):
"""MySQL NUMERIC type."""
__visit_name__ = "NUMERIC"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a NUMERIC.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(NUMERIC, self).__init__(
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
class DECIMAL(_NumericType, sqltypes.DECIMAL):
"""MySQL DECIMAL type."""
__visit_name__ = "DECIMAL"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DECIMAL.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(DECIMAL, self).__init__(
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
class DOUBLE(_FloatType):
"""MySQL DOUBLE type."""
__visit_name__ = "DOUBLE"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a DOUBLE.
.. note::
The :class:`.DOUBLE` type by default converts from float
to Decimal, using a truncation that defaults to 10 digits.
Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
to change this scale, or ``asdecimal=False`` to return values
directly as Python floating points.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(DOUBLE, self).__init__(
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
class REAL(_FloatType, sqltypes.REAL):
"""MySQL REAL type."""
__visit_name__ = "REAL"
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
"""Construct a REAL.
.. note::
The :class:`.REAL` type by default converts from float
to Decimal, using a truncation that defaults to 10 digits.
Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
to change this scale, or ``asdecimal=False`` to return values
directly as Python floating points.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(REAL, self).__init__(
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
class FLOAT(_FloatType, sqltypes.FLOAT):
"""MySQL FLOAT type."""
__visit_name__ = "FLOAT"
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
"""Construct a FLOAT.
:param precision: Total digits in this number. If scale and precision
are both None, values are stored to limits allowed by the server.
:param scale: The number of digits after the decimal point.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(FLOAT, self).__init__(
precision=precision, scale=scale, asdecimal=asdecimal, **kw
)
def bind_processor(self, dialect):
return None
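# Illustrative sketch (added; not part of the original module) of the
# asdecimal flag documented above:
#
#     Column("d", DOUBLE(precision=10, scale=2))                   # Decimal('1.23')
#     Column("d", DOUBLE(precision=10, scale=2, asdecimal=False))  # 1.23 (float)
#
# FLOAT differs from DOUBLE and REAL in defaulting to asdecimal=False,
# returning plain Python floats.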
class INTEGER(_IntegerType, sqltypes.INTEGER):
"""MySQL INTEGER type."""
__visit_name__ = "INTEGER"
def __init__(self, display_width=None, **kw):
"""Construct an INTEGER.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(INTEGER, self).__init__(display_width=display_width, **kw)
class BIGINT(_IntegerType, sqltypes.BIGINT):
"""MySQL BIGINTEGER type."""
__visit_name__ = "BIGINT"
def __init__(self, display_width=None, **kw):
"""Construct a BIGINTEGER.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(BIGINT, self).__init__(display_width=display_width, **kw)
class MEDIUMINT(_IntegerType):
"""MySQL MEDIUMINTEGER type."""
__visit_name__ = "MEDIUMINT"
def __init__(self, display_width=None, **kw):
"""Construct a MEDIUMINTEGER
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(MEDIUMINT, self).__init__(display_width=display_width, **kw)
class TINYINT(_IntegerType):
"""MySQL TINYINT type."""
__visit_name__ = "TINYINT"
def __init__(self, display_width=None, **kw):
"""Construct a TINYINT.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(TINYINT, self).__init__(display_width=display_width, **kw)
class SMALLINT(_IntegerType, sqltypes.SMALLINT):
"""MySQL SMALLINTEGER type."""
__visit_name__ = "SMALLINT"
def __init__(self, display_width=None, **kw):
"""Construct a SMALLINTEGER.
:param display_width: Optional, maximum display width for this number.
:param unsigned: a boolean, optional.
:param zerofill: Optional. If true, values will be stored as strings
left-padded with zeros. Note that this does not affect the values
returned by the underlying database API, which continue to be
numeric.
"""
super(SMALLINT, self).__init__(display_width=display_width, **kw)
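# Hedged example (added; not in the original source) of how the integer
# keyword arguments documented above typically render in MySQL DDL:
#
#     Column("n", INTEGER(display_width=11, unsigned=True, zerofill=True))
#     # -> n INTEGER(11) UNSIGNED ZEROFILL
#
# The actual rendering is performed by the MySQL type compiler elsewhere
# in this dialect.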
class BIT(sqltypes.TypeEngine):
"""MySQL BIT type.
This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
for MyISAM, MEMORY, InnoDB and BDB. For older versions, use the
MSTinyInteger() type.
"""
__visit_name__ = "BIT"
def __init__(self, length=None):
"""Construct a BIT.
:param length: Optional, number of bits.
"""
self.length = length
def result_processor(self, dialect, coltype):
"""Convert a MySQL's 64 bit, variable length binary string to a long.
TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector
already do this, so this logic should be moved to those dialects.
"""
def process(value):
if value is not None:
v = 0
for i in value:
if not isinstance(i, int):
i = ord(i) # convert byte to int on Python 2
v = v << 8 | i
return v
return value
return process
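# Worked example (added for illustration) of the byte-folding performed by
# BIT.result_processor above:
#
#     process(b"\x02\x01")  # -> (0x02 << 8) | 0x01 == 513
#
# i.e. the driver's raw big-endian byte string is accumulated into a
# single Python integer.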
class TIME(sqltypes.TIME):
"""MySQL TIME type."""
__visit_name__ = "TIME"
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIME type.
:param timezone: not used by the MySQL dialect.
:param fsp: fractional seconds precision value.
MySQL 5.6 supports storage of fractional seconds;
this parameter will be used when emitting DDL
for the TIME type.
.. note::
DBAPI driver support for fractional seconds may
be limited; current support includes
MySQL Connector/Python.
"""
super(TIME, self).__init__(timezone=timezone)
self.fsp = fsp
def result_processor(self, dialect, coltype):
time = datetime.time
def process(value):
# convert from a timedelta value
if value is not None:
microseconds = value.microseconds
seconds = value.seconds
minutes = seconds // 60
return time(
minutes // 60,
minutes % 60,
seconds - minutes * 60,
microsecond=microseconds,
)
else:
return None
return process
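# Worked example (added for illustration) of the timedelta-to-time
# conversion above, assuming the DBAPI returns TIME values as timedelta
# objects as MySQL drivers conventionally do:
#
#     process(datetime.timedelta(hours=5, minutes=30, seconds=15,
#                                microseconds=250))
#     # seconds=19815, minutes=330 -> datetime.time(5, 30, 15, 250)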
class TIMESTAMP(sqltypes.TIMESTAMP):
"""MySQL TIMESTAMP type."""
__visit_name__ = "TIMESTAMP"
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL TIMESTAMP type.
:param timezone: not used by the MySQL dialect.
:param fsp: fractional seconds precision value.
MySQL 5.6.4 supports storage of fractional seconds;
this parameter will be used when emitting DDL
for the TIMESTAMP type.
.. note::
DBAPI driver support for fractional seconds may
be limited; current support includes
MySQL Connector/Python.
"""
super(TIMESTAMP, self).__init__(timezone=timezone)
self.fsp = fsp
class DATETIME(sqltypes.DATETIME):
"""MySQL DATETIME type."""
__visit_name__ = "DATETIME"
def __init__(self, timezone=False, fsp=None):
"""Construct a MySQL DATETIME type.
:param timezone: not used by the MySQL dialect.
:param fsp: fractional seconds precision value.
MySQL 5.6.4 supports storage of fractional seconds;
this parameter will be used when emitting DDL
for the DATETIME type.
.. note::
DBAPI driver support for fractional seconds may
be limited; current support includes
MySQL Connector/Python.
"""
super(DATETIME, self).__init__(timezone=timezone)
self.fsp = fsp
class YEAR(sqltypes.TypeEngine):
"""MySQL YEAR type, for single byte storage of years 1901-2155."""
__visit_name__ = "YEAR"
def __init__(self, display_width=None):
self.display_width = display_width
class TEXT(_StringType, sqltypes.TEXT):
"""MySQL TEXT type, for text up to 2^16 characters."""
__visit_name__ = "TEXT"
def __init__(self, length=None, **kw):
"""Construct a TEXT.
:param length: Optional, if provided the server may optimize storage
by substituting the smallest TEXT type sufficient to store
``length`` characters.
:param charset: Optional, a column-level character set for this string
value. Takes precedence over 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence over 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(TEXT, self).__init__(length=length, **kw)
class TINYTEXT(_StringType):
"""MySQL TINYTEXT type, for text up to 2^8 characters."""
__visit_name__ = "TINYTEXT"
def __init__(self, **kwargs):
"""Construct a TINYTEXT.
:param charset: Optional, a column-level character set for this string
value. Takes precedence over 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence over 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(TINYTEXT, self).__init__(**kwargs)
class MEDIUMTEXT(_StringType):
"""MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
__visit_name__ = "MEDIUMTEXT"
def __init__(self, **kwargs):
"""Construct a MEDIUMTEXT.
:param charset: Optional, a column-level character set for this string
value. Takes precedence over 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence over 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(MEDIUMTEXT, self).__init__(**kwargs)
class LONGTEXT(_StringType):
"""MySQL LONGTEXT type, for text up to 2^32 characters."""
__visit_name__ = "LONGTEXT"
def __init__(self, **kwargs):
"""Construct a LONGTEXT.
:param charset: Optional, a column-level character set for this string
value. Takes precedence over 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence over 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(LONGTEXT, self).__init__(**kwargs)
class VARCHAR(_StringType, sqltypes.VARCHAR):
"""MySQL VARCHAR type, for variable-length character data."""
__visit_name__ = "VARCHAR"
def __init__(self, length=None, **kwargs):
"""Construct a VARCHAR.
:param charset: Optional, a column-level character set for this string
value. Takes precedence over 'ascii' or 'unicode' short-hand.
:param collation: Optional, a column-level collation for this string
value. Takes precedence over 'binary' short-hand.
:param ascii: Defaults to False: short-hand for the ``latin1``
character set, generates ASCII in schema.
:param unicode: Defaults to False: short-hand for the ``ucs2``
character set, generates UNICODE in schema.
:param national: Optional. If true, use the server's configured
national character set.
:param binary: Defaults to False: short-hand, pick the binary
collation type that matches the column's character set. Generates
BINARY in schema. This does not affect the type of data stored,
only the collation of character data.
"""
super(VARCHAR, self).__init__(length=length, **kwargs)
class CHAR(_StringType, sqltypes.CHAR):
"""MySQL CHAR type, for fixed-length character data."""
__visit_name__ = "CHAR"
def __init__(self, length=None, **kwargs):
"""Construct a CHAR.
:param length: Maximum data length, in characters.
:param binary: Optional, use the default binary collation for the
national character set. This does not affect the type of data
stored, use a BINARY type for binary data.
:param collation: Optional, request a particular collation. Must be
compatible with the national character set.
"""
super(CHAR, self).__init__(length=length, **kwargs)
@classmethod
def _adapt_string_for_cast(self, type_):
# copy the given string type into a CHAR
# for the purposes of rendering a CAST expression
type_ = sqltypes.to_instance(type_)
if isinstance(type_, sqltypes.CHAR):
return type_
elif isinstance(type_, _StringType):
return CHAR(
length=type_.length,
charset=type_.charset,
collation=type_.collation,
ascii=type_.ascii,
binary=type_.binary,
unicode=type_.unicode,
national=False, # not supported in CAST
)
else:
return CHAR(length=type_.length)
class NVARCHAR(_StringType, sqltypes.NVARCHAR):
"""MySQL NVARCHAR type.
For variable-length character data in the server's configured national
character set.
"""
__visit_name__ = "NVARCHAR"
def __init__(self, length=None, **kwargs):
"""Construct an NVARCHAR.
:param length: Maximum data length, in characters.
:param binary: Optional, use the default binary collation for the
national character set. This does not affect the type of data
stored, use a BINARY type for binary data.
:param collation: Optional, request a particular collation. Must be
compatible with the national character set.
"""
kwargs["national"] = True
super(NVARCHAR, self).__init__(length=length, **kwargs)
class NCHAR(_StringType, sqltypes.NCHAR):
"""MySQL NCHAR type.
For fixed-length character data in the server's configured national
character set.
"""
__visit_name__ = "NCHAR"
def __init__(self, length=None, **kwargs):
"""Construct an NCHAR.
:param length: Maximum data length, in characters.
:param binary: Optional, use the default binary collation for the
national character set. This does not affect the type of data
stored, use a BINARY type for binary data.
:param collation: Optional, request a particular collation. Must be
compatible with the national character set.
"""
kwargs["national"] = True
super(NCHAR, self).__init__(length=length, **kwargs)
class TINYBLOB(sqltypes._Binary):
"""MySQL TINYBLOB type, for binary data up to 2^8 bytes."""
__visit_name__ = "TINYBLOB"
class MEDIUMBLOB(sqltypes._Binary):
"""MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes."""
__visit_name__ = "MEDIUMBLOB"
class LONGBLOB(sqltypes._Binary):
"""MySQL LONGBLOB type, for binary data up to 2^32 bytes."""
__visit_name__ = "LONGBLOB"

View File

@@ -0,0 +1,58 @@
# oracle/__init__.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from . import base # noqa
from . import cx_oracle # noqa
from .base import BFILE
from .base import BINARY_DOUBLE
from .base import BINARY_FLOAT
from .base import BLOB
from .base import CHAR
from .base import CLOB
from .base import DATE
from .base import DOUBLE_PRECISION
from .base import FLOAT
from .base import INTERVAL
from .base import LONG
from .base import NCHAR
from .base import NCLOB
from .base import NUMBER
from .base import NVARCHAR
from .base import NVARCHAR2
from .base import RAW
from .base import ROWID
from .base import TIMESTAMP
from .base import VARCHAR
from .base import VARCHAR2
base.dialect = dialect = cx_oracle.dialect
__all__ = (
"VARCHAR",
"NVARCHAR",
"CHAR",
"NCHAR",
"DATE",
"NUMBER",
"BLOB",
"BFILE",
"CLOB",
"NCLOB",
"TIMESTAMP",
"RAW",
"FLOAT",
"DOUBLE_PRECISION",
"BINARY_DOUBLE",
"BINARY_FLOAT",
"LONG",
"dialect",
"INTERVAL",
"VARCHAR2",
"NVARCHAR2",
"ROWID",
)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,160 @@
from ... import create_engine
from ... import exc
from ...engine import url as sa_url
from ...testing.provision import configure_follower
from ...testing.provision import create_db
from ...testing.provision import drop_db
from ...testing.provision import follower_url_from_main
from ...testing.provision import log
from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
from ...testing.provision import set_default_schema_on_connection
from ...testing.provision import stop_test_class_outside_fixtures
from ...testing.provision import temp_table_keyword_args
@create_db.for_db("oracle")
def _oracle_create_db(cfg, eng, ident):
# NOTE: make sure you've run "ALTER DATABASE default tablespace users" or
# similar, so that the default tablespace is not "system"; reflection will
# fail otherwise
with eng.begin() as conn:
conn.exec_driver_sql("create user %s identified by xe" % ident)
conn.exec_driver_sql("create user %s_ts1 identified by xe" % ident)
conn.exec_driver_sql("create user %s_ts2 identified by xe" % ident)
conn.exec_driver_sql("grant dba to %s" % (ident,))
conn.exec_driver_sql("grant unlimited tablespace to %s" % ident)
conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident)
conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident)
@configure_follower.for_db("oracle")
def _oracle_configure_follower(config, ident):
config.test_schema = "%s_ts1" % ident
config.test_schema_2 = "%s_ts2" % ident
def _ora_drop_ignore(conn, dbname):
try:
conn.exec_driver_sql("drop user %s cascade" % dbname)
log.info("Reaped db: %s", dbname)
return True
except exc.DatabaseError as err:
log.warning("couldn't drop db: %s", err)
return False
@drop_db.for_db("oracle")
def _oracle_drop_db(cfg, eng, ident):
with eng.begin() as conn:
# cx_Oracle seems to occasionally leak open connections when a large
# suite is run, even if we confirm we have zero references to
# connection objects.
# while there is a "kill session" command in Oracle,
# it unfortunately does not release the connection sufficiently.
_ora_drop_ignore(conn, ident)
_ora_drop_ignore(conn, "%s_ts1" % ident)
_ora_drop_ignore(conn, "%s_ts2" % ident)
@stop_test_class_outside_fixtures.for_db("oracle")
def stop_test_class_outside_fixtures(config, db, cls):
try:
with db.begin() as conn:
# run magic command to get rid of identity sequences
# https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501
conn.exec_driver_sql("purge recyclebin")
except exc.DatabaseError as err:
log.warning("purge recyclebin command failed: %s", err)
# clear statement cache on all connections that were used
# https://github.com/oracle/python-cx_Oracle/issues/519
for cx_oracle_conn in _all_conns:
try:
sc = cx_oracle_conn.stmtcachesize
except db.dialect.dbapi.InterfaceError:
# connection closed
pass
else:
# setting stmtcachesize to 0 discards cached statements on this
# connection; restoring the original size re-enables caching
cx_oracle_conn.stmtcachesize = 0
cx_oracle_conn.stmtcachesize = sc
_all_conns.clear()
_all_conns = set()
@post_configure_engine.for_db("oracle")
def _oracle_post_configure_engine(url, engine, follower_ident):
from sqlalchemy import event
@event.listens_for(engine, "checkout")
def checkout(dbapi_con, con_record, con_proxy):
_all_conns.add(dbapi_con)
@event.listens_for(engine, "checkin")
def checkin(dbapi_connection, connection_record):
# work around cx_Oracle issue:
# https://github.com/oracle/python-cx_Oracle/issues/530
# invalidate oracle connections that had 2pc set up
if "cx_oracle_xid" in connection_record.info:
connection_record.invalidate()
@run_reap_dbs.for_db("oracle")
def _reap_oracle_dbs(url, idents):
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
with eng.begin() as conn:
log.info("identifiers in file: %s", ", ".join(idents))
to_reap = conn.exec_driver_sql(
"select u.username from all_users u where username "
"like 'TEST_%' and not exists (select username "
"from v$session where username=u.username)"
)
all_names = {username.lower() for (username,) in to_reap}
to_drop = set()
for name in all_names:
if name.endswith("_ts1") or name.endswith("_ts2"):
continue
elif name in idents:
to_drop.add(name)
if "%s_ts1" % name in all_names:
to_drop.add("%s_ts1" % name)
if "%s_ts2" % name in all_names:
to_drop.add("%s_ts2" % name)
dropped = total = 0
for total, username in enumerate(to_drop, 1):
if _ora_drop_ignore(conn, username):
dropped += 1
log.info(
"Dropped %d out of %d stale databases detected", dropped, total
)
@follower_url_from_main.for_db("oracle")
def _oracle_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
return url.set(username=ident, password="xe")
@temp_table_keyword_args.for_db("oracle")
def _oracle_temp_table_keyword_args(cfg, eng):
return {
"prefixes": ["GLOBAL TEMPORARY"],
"oracle_on_commit": "PRESERVE ROWS",
}
@set_default_schema_on_connection.for_db("oracle")
def _oracle_set_default_schema_on_connection(
cfg, dbapi_connection, schema_name
):
cursor = dbapi_connection.cursor()
cursor.execute("ALTER SESSION SET CURRENT_SCHEMA=%s" % schema_name)
cursor.close()

View File

@@ -0,0 +1,117 @@
# postgresql/__init__.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from . import base
from . import pg8000 # noqa
from . import psycopg2 # noqa
from . import psycopg2cffi # noqa
from . import pygresql # noqa
from . import pypostgresql # noqa
from .array import All
from .array import Any
from .array import ARRAY
from .array import array
from .base import BIGINT
from .base import BIT
from .base import BOOLEAN
from .base import BYTEA
from .base import CHAR
from .base import CIDR
from .base import CreateEnumType
from .base import DATE
from .base import DOUBLE_PRECISION
from .base import DropEnumType
from .base import ENUM
from .base import FLOAT
from .base import INET
from .base import INTEGER
from .base import INTERVAL
from .base import MACADDR
from .base import MONEY
from .base import NUMERIC
from .base import OID
from .base import REAL
from .base import REGCLASS
from .base import SMALLINT
from .base import TEXT
from .base import TIME
from .base import TIMESTAMP
from .base import TSVECTOR
from .base import UUID
from .base import VARCHAR
from .dml import Insert
from .dml import insert
from .ext import aggregate_order_by
from .ext import array_agg
from .ext import ExcludeConstraint
from .hstore import HSTORE
from .hstore import hstore
from .json import JSON
from .json import JSONB
from .ranges import DATERANGE
from .ranges import INT4RANGE
from .ranges import INT8RANGE
from .ranges import NUMRANGE
from .ranges import TSRANGE
from .ranges import TSTZRANGE
from ...util import compat
if compat.py3k:
from . import asyncpg # noqa
base.dialect = dialect = psycopg2.dialect
__all__ = (
"INTEGER",
"BIGINT",
"SMALLINT",
"VARCHAR",
"CHAR",
"TEXT",
"NUMERIC",
"FLOAT",
"REAL",
"INET",
"CIDR",
"UUID",
"BIT",
"MACADDR",
"MONEY",
"OID",
"REGCLASS",
"DOUBLE_PRECISION",
"TIMESTAMP",
"TIME",
"DATE",
"BYTEA",
"BOOLEAN",
"INTERVAL",
"ARRAY",
"ENUM",
"dialect",
"array",
"HSTORE",
"hstore",
"INT4RANGE",
"INT8RANGE",
"NUMRANGE",
"DATERANGE",
"TSVECTOR",
"TSRANGE",
"TSTZRANGE",
"JSON",
"JSONB",
"Any",
"All",
"DropEnumType",
"CreateEnumType",
"ExcludeConstraint",
"aggregate_order_by",
"array_agg",
"insert",
"Insert",
)

View File

@@ -0,0 +1,413 @@
# postgresql/array.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
import re
from ... import types as sqltypes
from ... import util
from ...sql import coercions
from ...sql import expression
from ...sql import operators
from ...sql import roles
def Any(other, arrexpr, operator=operators.eq):
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
See that method for details.
"""
return arrexpr.any(other, operator)
def All(other, arrexpr, operator=operators.eq):
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
See that method for details.
"""
return arrexpr.all(other, operator)
class array(expression.ClauseList, expression.ColumnElement):
"""A PostgreSQL ARRAY literal.
This is used to produce ARRAY literals in SQL expressions, e.g.::
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.dialects import postgresql
from sqlalchemy import select, func
stmt = select(array([1,2]) + array([3,4,5]))
print(stmt.compile(dialect=postgresql.dialect()))
Produces the SQL::
SELECT ARRAY[%(param_1)s, %(param_2)s] ||
ARRAY[%(param_3)s, %(param_4)s, %(param_5)s] AS anon_1
An instance of :class:`.array` will always have the datatype
:class:`_types.ARRAY`. The "inner" type of the array is inferred from
the values present, unless the ``type_`` keyword argument is passed::
array(['foo', 'bar'], type_=CHAR)
Multidimensional arrays are produced by nesting :class:`.array` constructs.
The dimensionality of the final :class:`_types.ARRAY`
type is calculated by
recursively adding the dimensions of the inner :class:`_types.ARRAY`
type::
stmt = select(
array([
array([1, 2]), array([3, 4]), array([column('q'), column('x')])
])
)
print(stmt.compile(dialect=postgresql.dialect()))
Produces::
SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s],
ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1
.. versionadded:: 1.3.6 added support for multidimensional array literals
.. seealso::
:class:`_postgresql.ARRAY`
"""
__visit_name__ = "array"
stringify_dialect = "postgresql"
inherit_cache = True
def __init__(self, clauses, **kw):
clauses = [
coercions.expect(roles.ExpressionElementRole, c) for c in clauses
]
super(array, self).__init__(*clauses, **kw)
self._type_tuple = [arg.type for arg in clauses]
main_type = kw.pop(
"type_",
self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE,
)
if isinstance(main_type, ARRAY):
self.type = ARRAY(
main_type.item_type,
dimensions=main_type.dimensions + 1
if main_type.dimensions is not None
else 2,
)
else:
self.type = ARRAY(main_type)
@property
def _select_iterable(self):
return (self,)
def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
if _assume_scalar or operator is operators.getitem:
return expression.BindParameter(
None,
obj,
_compared_to_operator=operator,
type_=type_,
_compared_to_type=self.type,
unique=True,
)
else:
return array(
[
self._bind_param(
operator, o, _assume_scalar=True, type_=type_
)
for o in obj
]
)
def self_group(self, against=None):
if against in (operators.any_op, operators.all_op, operators.getitem):
return expression.Grouping(self)
else:
return self
CONTAINS = operators.custom_op("@>", precedence=5, is_comparison=True)
CONTAINED_BY = operators.custom_op("<@", precedence=5, is_comparison=True)
OVERLAP = operators.custom_op("&&", precedence=5, is_comparison=True)
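# For orientation (added illustration; not in the original source), these
# custom operators correspond to PostgreSQL array predicates such as:
#
#     ARRAY[1,2,3] @> ARRAY[2]    -- contains
#     ARRAY[2] <@ ARRAY[1,2,3]    -- contained_by
#     ARRAY[1,2] && ARRAY[2,3]    -- overlap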
class ARRAY(sqltypes.ARRAY):
"""PostgreSQL ARRAY type.
.. versionchanged:: 1.1 The :class:`_postgresql.ARRAY` type is now
a subclass of the core :class:`_types.ARRAY` type.
The :class:`_postgresql.ARRAY` type is constructed in the same way
as the core :class:`_types.ARRAY` type; a member type is required, and a
number of dimensions is recommended if the type is to be used for more
than one dimension::
from sqlalchemy.dialects import postgresql
mytable = Table("mytable", metadata,
Column("data", postgresql.ARRAY(Integer, dimensions=2))
)
The :class:`_postgresql.ARRAY` type provides all operations defined on the
core :class:`_types.ARRAY` type, including support for "dimensions",
indexed access, and simple matching such as
:meth:`.types.ARRAY.Comparator.any` and
:meth:`.types.ARRAY.Comparator.all`. :class:`_postgresql.ARRAY`
class also
provides PostgreSQL-specific methods for containment operations, including
:meth:`.postgresql.ARRAY.Comparator.contains`,
:meth:`.postgresql.ARRAY.Comparator.contained_by`, and
:meth:`.postgresql.ARRAY.Comparator.overlap`, e.g.::
mytable.c.data.contains([1, 2])
The :class:`_postgresql.ARRAY` type may not be supported on all
PostgreSQL DBAPIs; it is currently known to work on psycopg2 only.
Additionally, the :class:`_postgresql.ARRAY`
type does not work directly in
conjunction with the :class:`.ENUM` type. For a workaround, see the
special type at :ref:`postgresql_array_of_enum`.
.. seealso::
:class:`_types.ARRAY` - base array type
:class:`_postgresql.array` - produces a literal array value.
"""
class Comparator(sqltypes.ARRAY.Comparator):
"""Define comparison operations for :class:`_types.ARRAY`.
Note that these operations are in addition to those provided
by the base :class:`.types.ARRAY.Comparator` class, including
:meth:`.types.ARRAY.Comparator.any` and
:meth:`.types.ARRAY.Comparator.all`.
"""
def contains(self, other, **kwargs):
"""Boolean expression. Test if elements are a superset of the
elements of the argument array expression.
kwargs may be ignored by this operator but are required for API
conformance.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
"""Boolean expression. Test if elements are a proper subset of the
elements of the argument array expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
def overlap(self, other):
"""Boolean expression. Test if array has elements in common with
an argument array expression.
"""
return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)
comparator_factory = Comparator
def __init__(
self, item_type, as_tuple=False, dimensions=None, zero_indexes=False
):
"""Construct an ARRAY.
E.g.::
Column('myarray', ARRAY(Integer))
Arguments are:
:param item_type: The data type of items of this array. Note that
dimensionality is irrelevant here, so multi-dimensional arrays like
``INTEGER[][]`` are constructed as ``ARRAY(Integer)``, not as
``ARRAY(ARRAY(Integer))`` or such.
:param as_tuple=False: Specify whether return results
should be converted to tuples from lists. DBAPIs such
as psycopg2 return lists by default. When tuples are
returned, the results are hashable.
:param dimensions: if non-None, the ARRAY will assume a fixed
number of dimensions. This will cause the DDL emitted for this
ARRAY to include the exact number of bracket clauses ``[]``,
and will also optimize the performance of the type overall.
Note that PG arrays are always implicitly "non-dimensioned",
meaning they can store any number of dimensions no matter how
they were declared.
:param zero_indexes=False: when True, index values will be converted
between Python zero-based and PostgreSQL one-based indexes, e.g.
a value of one will be added to all index values before passing
to the database.
.. versionadded:: 0.9.5
"""
if isinstance(item_type, ARRAY):
raise ValueError(
"Do not nest ARRAY types; ARRAY(basetype) "
"handles multi-dimensional arrays of basetype"
)
if isinstance(item_type, type):
item_type = item_type()
self.item_type = item_type
self.as_tuple = as_tuple
self.dimensions = dimensions
self.zero_indexes = zero_indexes
@property
def hashable(self):
return self.as_tuple
@property
def python_type(self):
return list
def compare_values(self, x, y):
return x == y
def _proc_array(self, arr, itemproc, dim, collection):
if dim is None:
arr = list(arr)
if (
dim == 1
or dim is None
and (
# this has to be (list, tuple), or at least
# not hasattr('__iter__'), since Py3K strings
# etc. have __iter__
not arr
or not isinstance(arr[0], (list, tuple))
)
):
if itemproc:
return collection(itemproc(x) for x in arr)
else:
return collection(arr)
else:
return collection(
self._proc_array(
x,
itemproc,
dim - 1 if dim is not None else None,
collection,
)
for x in arr
)
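# Worked example (added for illustration): with a known dimensionality the
# recursion peels one level per call and applies itemproc only at the
# innermost level, e.g.
#
#     self._proc_array([[1, 2], [3, 4]], str, 2, list)
#     # -> [['1', '2'], ['3', '4']]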
@util.memoized_property
def _against_native_enum(self):
return (
isinstance(self.item_type, sqltypes.Enum)
and self.item_type.native_enum
)
def bind_expression(self, bindvalue):
return bindvalue
def bind_processor(self, dialect):
item_proc = self.item_type.dialect_impl(dialect).bind_processor(
dialect
)
def process(value):
if value is None:
return value
else:
return self._proc_array(
value, item_proc, self.dimensions, list
)
return process
def result_processor(self, dialect, coltype):
item_proc = self.item_type.dialect_impl(dialect).result_processor(
dialect, coltype
)
def process(value):
if value is None:
return value
else:
return self._proc_array(
value,
item_proc,
self.dimensions,
tuple if self.as_tuple else list,
)
if self._against_native_enum:
super_rp = process
pattern = re.compile(r"^{(.*)}$")
def handle_raw_string(value):
inner = pattern.match(value).group(1)
return _split_enum_values(inner)
def process(value):
if value is None:
return value
# isinstance(value, util.string_types) is required to handle
# the case where a TypeDecorator for an ARRAY of ENUM is
# used, as was required in sa < 1.3.17
return super_rp(
handle_raw_string(value)
if isinstance(value, util.string_types)
else value
)
return process
def _split_enum_values(array_string):
if '"' not in array_string:
# no escape char is present so it can just split on the comma
return array_string.split(",") if array_string else []
# handles quoted strings from:
# r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr'
# returns
# ['abc', 'quoted', 'also\\quoted', 'quoted, comma', 'esc " quot', 'qpr']
text = array_string.replace(r"\"", "_$ESC_QUOTE$_")
text = text.replace(r"\\", "\\")
result = []
on_quotes = re.split(r'(")', text)
in_quotes = False
for tok in on_quotes:
if tok == '"':
in_quotes = not in_quotes
elif in_quotes:
result.append(tok.replace("_$ESC_QUOTE$_", '"'))
else:
result.extend(re.findall(r"([^\s,]+),?", tok))
return result
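# Doctest-style illustration (added; not part of the original module):
#
#     _split_enum_values('abc,"quoted, comma",qpr')
#     # -> ['abc', 'quoted, comma', 'qpr']
#
# Quoted members may contain commas; unquoted members are split on commas
# and surrounding whitespace.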

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,274 @@
# postgresql/on_conflict.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from . import ext
from ... import util
from ...sql import coercions
from ...sql import roles
from ...sql import schema
from ...sql.base import _exclusive_against
from ...sql.base import _generative
from ...sql.base import ColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.expression import alias
from ...util.langhelpers import public_factory
__all__ = ("Insert", "insert")
class Insert(StandardInsert):
"""PostgreSQL-specific implementation of INSERT.
Adds methods for PG-specific syntaxes such as ON CONFLICT.
The :class:`_postgresql.Insert` object is created using the
:func:`sqlalchemy.dialects.postgresql.insert` function.
.. versionadded:: 1.1
"""
stringify_dialect = "postgresql"
inherit_cache = False
@util.memoized_property
def excluded(self):
"""Provide the ``excluded`` namespace for an ON CONFLICT statement
PG's ON CONFLICT clause allows reference to the row that would
be inserted, known as ``excluded``. This attribute provides
all columns in this row to be referenceable.
.. tip:: The :attr:`_postgresql.Insert.excluded` attribute is an
instance of :class:`_expression.ColumnCollection`, which provides
an interface the same as that of the :attr:`_schema.Table.c`
collection described at :ref:`metadata_tables_and_columns`.
With this collection, ordinary names are accessible like attributes
(e.g. ``stmt.excluded.some_column``), but special names and
dictionary method names should be accessed using indexed access,
such as ``stmt.excluded["column name"]`` or
``stmt.excluded["values"]``. See the docstring for
:class:`_expression.ColumnCollection` for further examples.
.. seealso::
:ref:`postgresql_insert_on_conflict` - example of how
to use :attr:`_expression.Insert.excluded`
"""
return alias(self.table, name="excluded").columns
_on_conflict_exclusive = _exclusive_against(
"_post_values_clause",
msgs={
"_post_values_clause": "This Insert construct already has "
"an ON CONFLICT clause established"
},
)
@_generative
@_on_conflict_exclusive
def on_conflict_do_update(
self,
constraint=None,
index_elements=None,
index_where=None,
set_=None,
where=None,
):
r"""
Specifies a DO UPDATE SET action for the ON CONFLICT clause.
Either the ``constraint`` or ``index_elements`` argument is
required, but only one of these can be specified.
:param constraint:
The name of a unique or exclusion constraint on the table,
or the constraint object itself if it has a .name attribute.
:param index_elements:
A sequence consisting of string column names, :class:`_schema.Column`
objects, or other column expression objects that will be used
to infer a target index.
:param index_where:
Additional WHERE criterion that can be used to infer a
conditional target index.
:param set\_:
A dictionary or other mapping object
where the keys are either names of columns in the target table,
or :class:`_schema.Column` objects or other ORM-mapped columns
matching that of the target table, and expressions or literals
as values, specifying the ``SET`` actions to take.
.. versionadded:: 1.4 The
:paramref:`_postgresql.Insert.on_conflict_do_update.set_`
parameter supports :class:`_schema.Column` objects from the target
:class:`_schema.Table` as keys.
.. warning:: This dictionary does **not** take into account
Python-specified default UPDATE values or generation functions,
e.g. those specified using :paramref:`_schema.Column.onupdate`.
These values will not be exercised for an ON CONFLICT style of
UPDATE, unless they are manually specified in the
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.
:param where:
Optional argument. If present, can be a literal SQL
string or an acceptable expression for a ``WHERE`` clause
that restricts the rows affected by ``DO UPDATE SET``. Rows
not meeting the ``WHERE`` condition will not be updated
(effectively a ``DO NOTHING`` for those rows).
.. versionadded:: 1.1
.. seealso::
:ref:`postgresql_insert_on_conflict`
"""
self._post_values_clause = OnConflictDoUpdate(
constraint, index_elements, index_where, set_, where
)
@_generative
@_on_conflict_exclusive
def on_conflict_do_nothing(
self, constraint=None, index_elements=None, index_where=None
):
"""
Specifies a DO NOTHING action for the ON CONFLICT clause.
The ``constraint`` and ``index_elements`` arguments
are optional, but only one of these can be specified.
:param constraint:
The name of a unique or exclusion constraint on the table,
or the constraint object itself if it has a .name attribute.
:param index_elements:
A sequence consisting of string column names, :class:`_schema.Column`
objects, or other column expression objects that will be used
to infer a target index.
:param index_where:
Additional WHERE criterion that can be used to infer a
conditional target index.
.. versionadded:: 1.1
.. seealso::
:ref:`postgresql_insert_on_conflict`
"""
self._post_values_clause = OnConflictDoNothing(
constraint, index_elements, index_where
)
insert = public_factory(
Insert, ".dialects.postgresql.insert", ".dialects.postgresql.Insert"
)
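# Hedged usage sketch (added; "my_table" is illustrative, not from the
# original source) of the ON CONFLICT methods defined above:
#
#     from sqlalchemy.dialects.postgresql import insert
#
#     stmt = insert(my_table).values(id=1, data="inserted")
#     stmt = stmt.on_conflict_do_update(
#         index_elements=[my_table.c.id],
#         set_={"data": stmt.excluded.data},
#     )
#     # -> INSERT ... ON CONFLICT (id) DO UPDATE SET data = excluded.data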
class OnConflictClause(ClauseElement):
stringify_dialect = "postgresql"
def __init__(self, constraint=None, index_elements=None, index_where=None):
if constraint is not None:
if not isinstance(constraint, util.string_types) and isinstance(
constraint,
(schema.Index, schema.Constraint, ext.ExcludeConstraint),
):
constraint = getattr(constraint, "name") or constraint
if constraint is not None:
if index_elements is not None:
raise ValueError(
"'constraint' and 'index_elements' are mutually exclusive"
)
if isinstance(constraint, util.string_types):
self.constraint_target = constraint
self.inferred_target_elements = None
self.inferred_target_whereclause = None
elif isinstance(constraint, schema.Index):
index_elements = constraint.expressions
index_where = constraint.dialect_options["postgresql"].get(
"where"
)
elif isinstance(constraint, ext.ExcludeConstraint):
index_elements = constraint.columns
index_where = constraint.where
else:
index_elements = constraint.columns
index_where = constraint.dialect_options["postgresql"].get(
"where"
)
if index_elements is not None:
self.constraint_target = None
self.inferred_target_elements = index_elements
self.inferred_target_whereclause = index_where
elif constraint is None:
self.constraint_target = (
self.inferred_target_elements
) = self.inferred_target_whereclause = None
class OnConflictDoNothing(OnConflictClause):
__visit_name__ = "on_conflict_do_nothing"
class OnConflictDoUpdate(OnConflictClause):
__visit_name__ = "on_conflict_do_update"
def __init__(
self,
constraint=None,
index_elements=None,
index_where=None,
set_=None,
where=None,
):
super(OnConflictDoUpdate, self).__init__(
constraint=constraint,
index_elements=index_elements,
index_where=index_where,
)
if (
self.inferred_target_elements is None
and self.constraint_target is None
):
raise ValueError(
"Either constraint or index_elements, "
"but not both, must be specified unless DO NOTHING"
)
if isinstance(set_, dict):
if not set_:
raise ValueError("set parameter dictionary must not be empty")
elif isinstance(set_, ColumnCollection):
set_ = dict(set_)
else:
raise ValueError(
"set parameter must be a non-empty dictionary "
"or a ColumnCollection such as the `.c.` collection "
"of a Table object"
)
self.update_values_to_set = [
(coercions.expect(roles.DMLColumnRole, key), value)
for key, value in set_.items()
]
self.update_whereclause = where

View File

@@ -0,0 +1,277 @@
# postgresql/ext.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from .array import ARRAY
from ... import util
from ...sql import coercions
from ...sql import elements
from ...sql import expression
from ...sql import functions
from ...sql import roles
from ...sql import schema
from ...sql.schema import ColumnCollectionConstraint
class aggregate_order_by(expression.ColumnElement):
"""Represent a PostgreSQL aggregate order by expression.
E.g.::
from sqlalchemy.dialects.postgresql import aggregate_order_by
expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc()))
stmt = select(expr)
would represent the expression::
SELECT array_agg(a ORDER BY b DESC) FROM table;
Similarly::
expr = func.string_agg(
table.c.a,
aggregate_order_by(literal_column("','"), table.c.a)
)
stmt = select(expr)
Would represent::
SELECT string_agg(a, ',' ORDER BY a) FROM table;
.. versionadded:: 1.1
.. versionchanged:: 1.2.13 - the ORDER BY argument may be multiple terms
.. seealso::
:class:`_functions.array_agg`
"""
__visit_name__ = "aggregate_order_by"
stringify_dialect = "postgresql"
inherit_cache = False
def __init__(self, target, *order_by):
self.target = coercions.expect(roles.ExpressionElementRole, target)
self.type = self.target.type
_lob = len(order_by)
if _lob == 0:
raise TypeError("at least one ORDER BY element is required")
elif _lob == 1:
self.order_by = coercions.expect(
roles.ExpressionElementRole, order_by[0]
)
else:
self.order_by = elements.ClauseList(
*order_by, _literal_as_text_role=roles.ExpressionElementRole
)
def self_group(self, against=None):
return self
def get_children(self, **kwargs):
return self.target, self.order_by
def _copy_internals(self, clone=elements._clone, **kw):
self.target = clone(self.target, **kw)
self.order_by = clone(self.order_by, **kw)
@property
def _from_objects(self):
return self.target._from_objects + self.order_by._from_objects
class ExcludeConstraint(ColumnCollectionConstraint):
"""A table-level EXCLUDE constraint.
Defines an EXCLUDE constraint as described in the `PostgreSQL
documentation`__.
__ https://www.postgresql.org/docs/current/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
""" # noqa
__visit_name__ = "exclude_constraint"
where = None
inherit_cache = False
create_drop_stringify_dialect = "postgresql"
@elements._document_text_coercion(
"where",
":class:`.ExcludeConstraint`",
":paramref:`.ExcludeConstraint.where`",
)
def __init__(self, *elements, **kw):
r"""
Create an :class:`.ExcludeConstraint` object.
E.g.::
const = ExcludeConstraint(
(Column('period'), '&&'),
(Column('group'), '='),
where=(Column('group') != 'some group'),
ops={'group': 'my_operator_class'}
)
The constraint is normally embedded into the :class:`_schema.Table`
construct
directly, or added later using :meth:`.append_constraint`::
some_table = Table(
'some_table', metadata,
Column('id', Integer, primary_key=True),
Column('period', TSRANGE()),
Column('group', String)
)
some_table.append_constraint(
ExcludeConstraint(
(some_table.c.period, '&&'),
(some_table.c.group, '='),
where=some_table.c.group != 'some group',
name='some_table_excl_const',
ops={'group': 'my_operator_class'}
)
)
:param \*elements:
A sequence of two tuples of the form ``(column, operator)`` where
"column" is a SQL expression element or a raw SQL string, most
typically a :class:`_schema.Column` object,
and "operator" is a string
containing the operator to use. In order to specify a column name
when a :class:`_schema.Column` object is not available,
while ensuring
that any necessary quoting rules take effect, an ad-hoc
:class:`_schema.Column` or :func:`_expression.column`
object should be
used.
:param name:
Optional, the in-database name of this constraint.
:param deferrable:
Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
issuing DDL for this constraint.
:param initially:
Optional string. If set, emit INITIALLY <value> when issuing DDL
for this constraint.
:param using:
Optional string. If set, emit USING <index_method> when issuing DDL
for this constraint. Defaults to 'gist'.
:param where:
Optional SQL expression construct or literal SQL string.
If set, emit WHERE <predicate> when issuing DDL
for this constraint.
:param ops:
Optional dictionary. Used to define operator classes for the
elements; works the same way as that of the
:ref:`postgresql_ops <postgresql_operator_classes>`
parameter specified to the :class:`_schema.Index` construct.
.. versionadded:: 1.3.21
.. seealso::
:ref:`postgresql_operator_classes` - general description of how
PostgreSQL operator classes are specified.
"""
columns = []
render_exprs = []
self.operators = {}
expressions, operators = zip(*elements)
for (expr, column, strname, add_element), operator in zip(
coercions.expect_col_expression_collection(
roles.DDLConstraintColumnRole, expressions
),
operators,
):
if add_element is not None:
columns.append(add_element)
name = column.name if column is not None else strname
if name is not None:
# backwards compat
self.operators[name] = operator
render_exprs.append((expr, name, operator))
self._render_exprs = render_exprs
ColumnCollectionConstraint.__init__(
self,
*columns,
name=kw.get("name"),
deferrable=kw.get("deferrable"),
initially=kw.get("initially")
)
self.using = kw.get("using", "gist")
where = kw.get("where")
if where is not None:
self.where = coercions.expect(roles.StatementOptionRole, where)
self.ops = kw.get("ops", {})
def _set_parent(self, table, **kw):
super(ExcludeConstraint, self)._set_parent(table)
self._render_exprs = [
(
expr if isinstance(expr, elements.ClauseElement) else colexpr,
name,
operator,
)
for (expr, name, operator), colexpr in util.zip_longest(
self._render_exprs, self.columns
)
]
def _copy(self, target_table=None, **kw):
elements = [
(
schema._copy_expression(expr, self.parent, target_table),
self.operators[expr.name],
)
for expr in self.columns
]
c = self.__class__(
*elements,
name=self.name,
deferrable=self.deferrable,
initially=self.initially,
where=self.where,
using=self.using
)
c.dispatch._update(self.dispatch)
return c
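# A minimal DDL sketch (illustrative, not part of the upstream module):
# compiling CreateTable against the PostgreSQL dialect shows the EXCLUDE
# clause that ExcludeConstraint generates; the table and column names
# here are assumptions for the example only.
def _example_exclude_constraint_ddl():  # pragma: no cover
    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.schema import CreateTable

    t = Table(
        "some_table",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("period", postgresql.TSRANGE()),
    )
    t.append_constraint(ExcludeConstraint((t.c.period, "&&"), using="gist"))
    # renders roughly:
    #   CREATE TABLE some_table (..., EXCLUDE USING gist (period WITH &&))
    return CreateTable(t).compile(dialect=postgresql.dialect())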
def array_agg(*arg, **kw):
"""PostgreSQL-specific form of :class:`_functions.array_agg`, ensures
return type is :class:`_postgresql.ARRAY` and not
the plain :class:`_types.ARRAY`, unless an explicit ``type_``
is passed.
.. versionadded:: 1.1
"""
kw["_default_array_type"] = ARRAY
return functions.func.array_agg(*arg, **kw)
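# A minimal usage sketch (illustrative, not part of the upstream module):
# because the return type is the PostgreSQL ARRAY variant, array-specific
# operations such as indexing remain available on the aggregate; the
# column name is an assumption for the example only.
def _example_array_agg():  # pragma: no cover
    from sqlalchemy import column, select

    expr = array_agg(column("data"))
    # index access works because the result type is postgresql.ARRAY
    return select(expr[1])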

View File

@@ -0,0 +1,455 @@
# postgresql/hstore.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
import re
from .array import ARRAY
from ... import types as sqltypes
from ... import util
from ...sql import functions as sqlfunc
from ...sql import operators
__all__ = ("HSTORE", "hstore")
idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
GETITEM = operators.custom_op(
"->",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
HAS_KEY = operators.custom_op(
"?",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
HAS_ALL = operators.custom_op(
"?&",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
HAS_ANY = operators.custom_op(
"?|",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
CONTAINS = operators.custom_op(
"@>",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
CONTAINED_BY = operators.custom_op(
"<@",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
"""Represent the PostgreSQL HSTORE type.
The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::
data_table = Table('data_table', metadata,
Column('id', Integer, primary_key=True),
Column('data', HSTORE)
)
with engine.connect() as conn:
conn.execute(
data_table.insert(),
data = {"key1": "value1", "key2": "value2"}
)
:class:`.HSTORE` provides for a wide range of operations, including:
* Index operations::
data_table.c.data['some key'] == 'some value'
* Containment operations::
data_table.c.data.has_key('some key')
data_table.c.data.has_all(['one', 'two', 'three'])
* Concatenation::
data_table.c.data + {"k1": "v1"}
For a full list of special methods see
:class:`.HSTORE.comparator_factory`.
For usage with the SQLAlchemy ORM, it may be desirable to combine
the usage of :class:`.HSTORE` with the :class:`.MutableDict` dictionary
provided by the :mod:`sqlalchemy.ext.mutable`
extension. This extension will allow "in-place" changes to the
dictionary, e.g. addition of new keys or replacement/removal of existing
keys to/from the current dictionary, to produce events which will be
detected by the unit of work::
from sqlalchemy.ext.mutable import MutableDict
class MyClass(Base):
__tablename__ = 'data_table'
id = Column(Integer, primary_key=True)
data = Column(MutableDict.as_mutable(HSTORE))
my_object = session.query(MyClass).one()
# in-place mutation, requires Mutable extension
# in order for the ORM to detect
my_object.data['some_key'] = 'some value'
session.commit()
When the :mod:`sqlalchemy.ext.mutable` extension is not used, the ORM
will not be alerted to any changes to the contents of an existing
dictionary, unless that dictionary value is re-assigned to the
HSTORE-attribute itself, thus generating a change event.
.. seealso::
:class:`.hstore` - render the PostgreSQL ``hstore()`` function.
"""
__visit_name__ = "HSTORE"
hashable = False
text_type = sqltypes.Text()
def __init__(self, text_type=None):
"""Construct a new :class:`.HSTORE`.
:param text_type: the type that should be used for indexed values.
Defaults to :class:`_types.Text`.
.. versionadded:: 1.1.0
"""
if text_type is not None:
self.text_type = text_type
class Comparator(
sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator
):
"""Define comparison operations for :class:`.HSTORE`."""
def has_key(self, other):
"""Boolean expression. Test for presence of a key. Note that the
key may be a SQLA expression.
"""
return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
def has_all(self, other):
"""Boolean expression. Test for presence of all keys in jsonb"""
return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
def has_any(self, other):
"""Boolean expression. Test for presence of any key in jsonb"""
return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
def contains(self, other, **kwargs):
"""Boolean expression. Test if keys (or array) are a superset
of/contained the keys of the argument jsonb expression.
kwargs may be ignored by this operator but are required for API
conformance.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
"""Boolean expression. Test if keys are a proper subset of the
keys of the argument hstore expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
def _setup_getitem(self, index):
return GETITEM, index, self.type.text_type
def defined(self, key):
"""Boolean expression. Test for presence of a non-NULL value for
the key. Note that the key may be a SQLA expression.
"""
return _HStoreDefinedFunction(self.expr, key)
def delete(self, key):
"""HStore expression. Returns the contents of this hstore with the
given key deleted. Note that the key may be a SQLA expression.
"""
if isinstance(key, dict):
key = _serialize_hstore(key)
return _HStoreDeleteFunction(self.expr, key)
def slice(self, array):
"""HStore expression. Returns a subset of an hstore defined by
array of keys.
"""
return _HStoreSliceFunction(self.expr, array)
def keys(self):
"""Text array expression. Returns array of keys."""
return _HStoreKeysFunction(self.expr)
def vals(self):
"""Text array expression. Returns array of values."""
return _HStoreValsFunction(self.expr)
def array(self):
"""Text array expression. Returns array of alternating keys and
values.
"""
return _HStoreArrayFunction(self.expr)
def matrix(self):
"""Text array expression. Returns array of [key, value] pairs."""
return _HStoreMatrixFunction(self.expr)
comparator_factory = Comparator
def bind_processor(self, dialect):
if util.py2k:
encoding = dialect.encoding
def process(value):
if isinstance(value, dict):
return _serialize_hstore(value).encode(encoding)
else:
return value
else:
def process(value):
if isinstance(value, dict):
return _serialize_hstore(value)
else:
return value
return process
def result_processor(self, dialect, coltype):
if util.py2k:
encoding = dialect.encoding
def process(value):
if value is not None:
return _parse_hstore(value.decode(encoding))
else:
return value
else:
def process(value):
if value is not None:
return _parse_hstore(value)
else:
return value
return process
class hstore(sqlfunc.GenericFunction):
"""Construct an hstore value within a SQL expression using the
PostgreSQL ``hstore()`` function.
The :class:`.hstore` function accepts one or two arguments as described
in the PostgreSQL documentation.
E.g.::
from sqlalchemy.dialects.postgresql import array, hstore
select(hstore('key1', 'value1'))
select(
hstore(
array(['key1', 'key2', 'key3']),
array(['value1', 'value2', 'value3'])
)
)
.. seealso::
:class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype.
"""
type = HSTORE
name = "hstore"
inherit_cache = True
class _HStoreDefinedFunction(sqlfunc.GenericFunction):
type = sqltypes.Boolean
name = "defined"
inherit_cache = True
class _HStoreDeleteFunction(sqlfunc.GenericFunction):
type = HSTORE
name = "delete"
inherit_cache = True
class _HStoreSliceFunction(sqlfunc.GenericFunction):
type = HSTORE
name = "slice"
inherit_cache = True
class _HStoreKeysFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = "akeys"
inherit_cache = True
class _HStoreValsFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = "avals"
inherit_cache = True
class _HStoreArrayFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = "hstore_to_array"
inherit_cache = True
class _HStoreMatrixFunction(sqlfunc.GenericFunction):
type = ARRAY(sqltypes.Text)
name = "hstore_to_matrix"
inherit_cache = True
#
# parsing. note that none of this is used with the psycopg2 backend,
# which provides its own native extensions.
#
# My best guess at the parsing rules of hstore literals, since no formal
# grammar is given. This is mostly reverse engineered from PG's input parser
# behavior.
HSTORE_PAIR_RE = re.compile(
r"""
(
"(?P<key> (\\ . | [^"])* )" # Quoted key
)
[ ]* => [ ]* # Pair operator, optional adjoining whitespace
(
(?P<value_null> NULL ) # NULL value
| "(?P<value> (\\ . | [^"])* )" # Quoted value
)
""",
re.VERBOSE,
)
HSTORE_DELIMITER_RE = re.compile(
r"""
[ ]* , [ ]*
""",
re.VERBOSE,
)
def _parse_error(hstore_str, pos):
"""format an unmarshalling error."""
ctx = 20
hslen = len(hstore_str)
parsed_tail = hstore_str[max(pos - ctx - 1, 0) : min(pos, hslen)]
residual = hstore_str[min(pos, hslen) : min(pos + ctx + 1, hslen)]
if len(parsed_tail) > ctx:
parsed_tail = "[...]" + parsed_tail[1:]
if len(residual) > ctx:
residual = residual[:-1] + "[...]"
return "After %r, could not parse residual at position %d: %r" % (
parsed_tail,
pos,
residual,
)
def _parse_hstore(hstore_str):
"""Parse an hstore from its literal string representation.
Attempts to approximate PG's hstore input parsing rules as closely as
possible. Although currently this is not strictly necessary, since the
current implementation of hstore's output syntax is stricter than what it
accepts as input, the documentation makes no guarantees that will always
be the case.
"""
result = {}
pos = 0
pair_match = HSTORE_PAIR_RE.match(hstore_str)
while pair_match is not None:
key = pair_match.group("key").replace(r"\"", '"').replace("\\\\", "\\")
if pair_match.group("value_null"):
value = None
else:
value = (
pair_match.group("value")
.replace(r"\"", '"')
.replace("\\\\", "\\")
)
result[key] = value
pos += pair_match.end()
delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:])
if delim_match is not None:
pos += delim_match.end()
pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:])
if pos != len(hstore_str):
raise ValueError(_parse_error(hstore_str, pos))
return result
def _serialize_hstore(val):
"""Serialize a dictionary into an hstore literal. Keys and values must
both be strings (except None for values).
"""
def esc(s, position):
if position == "value" and s is None:
return "NULL"
elif isinstance(s, util.string_types):
return '"%s"' % s.replace("\\", "\\\\").replace('"', r"\"")
else:
raise ValueError(
"%r in %s position is not a string." % (s, position)
)
return ", ".join(
"%s=>%s" % (esc(k, "key"), esc(v, "value")) for k, v in val.items()
)
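# A minimal round-trip sketch (illustrative, not part of the upstream
# module): exercises the literal serializer and parser defined above.
def _example_hstore_roundtrip():  # pragma: no cover
    literal = _serialize_hstore({"key1": "value1", "key2": None})
    # produces, e.g.: '"key1"=>"value1", "key2"=>NULL'
    assert _parse_hstore(literal) == {"key1": "value1", "key2": None}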

View File

@@ -0,0 +1,327 @@
# postgresql/json.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import absolute_import
from ... import types as sqltypes
from ... import util
from ...sql import operators
__all__ = ("JSON", "JSONB")
idx_precedence = operators._PRECEDENCE[operators.json_getitem_op]
ASTEXT = operators.custom_op(
"->>",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
JSONPATH_ASTEXT = operators.custom_op(
"#>>",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
HAS_KEY = operators.custom_op(
"?",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
HAS_ALL = operators.custom_op(
"?&",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
HAS_ANY = operators.custom_op(
"?|",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
CONTAINS = operators.custom_op(
"@>",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
CONTAINED_BY = operators.custom_op(
"<@",
precedence=idx_precedence,
natural_self_precedent=True,
eager_grouping=True,
)
class JSONPathType(sqltypes.JSON.JSONPathType):
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)
def process(value):
assert isinstance(value, util.collections_abc.Sequence)
tokens = [util.text_type(elem) for elem in value]
value = "{%s}" % (", ".join(tokens))
if super_proc:
value = super_proc(value)
return value
return process
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)
def process(value):
assert isinstance(value, util.collections_abc.Sequence)
tokens = [util.text_type(elem) for elem in value]
value = "{%s}" % (", ".join(tokens))
if super_proc:
value = super_proc(value)
return value
return process
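# A minimal rendering sketch (illustrative, not part of the upstream
# module): a sequence of path elements is joined into PostgreSQL's
# "{elem, elem}" literal form by the processors above.
def _example_json_path_literal():  # pragma: no cover
    from sqlalchemy.dialects import postgresql

    proc = JSONPathType().literal_processor(postgresql.dialect())
    # returns roughly: "'{key_1, key_2, 5}'"
    return proc(("key_1", "key_2", 5))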
class JSON(sqltypes.JSON):
"""Represent the PostgreSQL JSON type.
:class:`_postgresql.JSON` is used automatically whenever the base
:class:`_types.JSON` datatype is used against a PostgreSQL backend,
however the base :class:`_types.JSON` datatype does not provide Python
accessors for PostgreSQL-specific comparison methods such as
:meth:`_postgresql.JSON.Comparator.astext`; additionally, to use
PostgreSQL ``JSONB``, the :class:`_postgresql.JSONB` datatype should
be used explicitly.
.. seealso::
:class:`_types.JSON` - main documentation for the generic
cross-platform JSON datatype.
The operators provided by the PostgreSQL version of :class:`_types.JSON`
include:
* Index operations (the ``->`` operator)::
data_table.c.data['some key']
data_table.c.data[5]
* Index operations returning text (the ``->>`` operator)::
data_table.c.data['some key'].astext == 'some value'
Note that equivalent functionality is available via the
:attr:`.JSON.Comparator.as_string` accessor.
* Index operations with CAST
(equivalent to ``CAST(col ->> ['some key'] AS <type>)``)::
data_table.c.data['some key'].astext.cast(Integer) == 5
Note that equivalent functionality is available via the
:attr:`.JSON.Comparator.as_integer` and similar accessors.
* Path index operations (the ``#>`` operator)::
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]
* Path index operations returning text (the ``#>>`` operator)::
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value'
.. versionchanged:: 1.1 The :meth:`_expression.ColumnElement.cast`
operator on
JSON objects now requires that the :attr:`.JSON.Comparator.astext`
modifier be called explicitly, if the cast works only from a textual
string.
Index operations return an expression object whose type defaults to
:class:`_types.JSON`, so that further JSON-oriented operations may be
invoked upon the result.
Custom serializers and deserializers are specified at the dialect level,
that is, using :func:`_sa.create_engine`. The reason for this is that when
using psycopg2, the DBAPI only allows serializers at the per-cursor
or per-connection level. E.g.::
engine = create_engine("postgresql://scott:tiger@localhost/test",
json_serializer=my_serialize_fn,
json_deserializer=my_deserialize_fn
)
When using the psycopg2 dialect, the json_deserializer is registered
against the database using ``psycopg2.extras.register_default_json``.
.. seealso::
:class:`_types.JSON` - Core level JSON type
:class:`_postgresql.JSONB`
.. versionchanged:: 1.1 :class:`_postgresql.JSON` is now a PostgreSQL-
specific specialization of the new :class:`_types.JSON` type.
""" # noqa
astext_type = sqltypes.Text()
def __init__(self, none_as_null=False, astext_type=None):
"""Construct a :class:`_types.JSON` type.
:param none_as_null: if True, persist the value ``None`` as a
SQL NULL value, not the JSON encoding of ``null``. Note that
when this flag is False, the :func:`.null` construct can still
be used to persist a NULL value::
from sqlalchemy import null
conn.execute(table.insert(), data=null())
.. versionchanged:: 0.9.8 - Added ``none_as_null``, and :func:`.null`
is now supported in order to persist a NULL value.
.. seealso::
:attr:`_types.JSON.NULL`
:param astext_type: the type to use for the
:attr:`.JSON.Comparator.astext`
accessor on indexed attributes. Defaults to :class:`_types.Text`.
.. versionadded:: 1.1
"""
super(JSON, self).__init__(none_as_null=none_as_null)
if astext_type is not None:
self.astext_type = astext_type
class Comparator(sqltypes.JSON.Comparator):
"""Define comparison operations for :class:`_types.JSON`."""
@property
def astext(self):
"""On an indexed expression, use the "astext" (e.g. "->>")
conversion when rendered in SQL.
E.g.::
select(data_table.c.data['some key'].astext)
.. seealso::
:meth:`_expression.ColumnElement.cast`
"""
if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
return self.expr.left.operate(
JSONPATH_ASTEXT,
self.expr.right,
result_type=self.type.astext_type,
)
else:
return self.expr.left.operate(
ASTEXT, self.expr.right, result_type=self.type.astext_type
)
comparator_factory = Comparator
class JSONB(JSON):
"""Represent the PostgreSQL JSONB type.
The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data,
e.g.::
data_table = Table('data_table', metadata,
Column('id', Integer, primary_key=True),
Column('data', JSONB)
)
with engine.connect() as conn:
conn.execute(
data_table.insert(),
data = {"key1": "value1", "key2": "value2"}
)
The :class:`_postgresql.JSONB` type includes all operations provided by
:class:`_types.JSON`, including the same behaviors for indexing
operations.
It also adds additional operators specific to JSONB, including
:meth:`.JSONB.Comparator.has_key`, :meth:`.JSONB.Comparator.has_all`,
:meth:`.JSONB.Comparator.has_any`, :meth:`.JSONB.Comparator.contains`,
and :meth:`.JSONB.Comparator.contained_by`.
Like the :class:`_types.JSON` type, the :class:`_postgresql.JSONB`
type does not detect
in-place changes when used with the ORM, unless the
:mod:`sqlalchemy.ext.mutable` extension is used.
Custom serializers and deserializers
are shared with the :class:`_types.JSON` class,
using the ``json_serializer``
and ``json_deserializer`` keyword arguments. These must be specified
at the dialect level using :func:`_sa.create_engine`. When using
psycopg2, the serializers are associated with the jsonb type using
``psycopg2.extras.register_default_jsonb`` on a per-connection basis,
in the same way that ``psycopg2.extras.register_default_json`` is used
to register these handlers with the json type.
.. versionadded:: 0.9.7
.. seealso::
:class:`_types.JSON`
"""
__visit_name__ = "JSONB"
class Comparator(JSON.Comparator):
"""Define comparison operations for :class:`_types.JSON`."""
def has_key(self, other):
"""Boolean expression. Test for presence of a key. Note that the
key may be a SQLA expression.
"""
return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
def has_all(self, other):
"""Boolean expression. Test for presence of all keys in jsonb"""
return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
def has_any(self, other):
"""Boolean expression. Test for presence of any key in jsonb"""
return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
def contains(self, other, **kwargs):
"""Boolean expression. Test if keys (or array) are a superset
of/contained the keys of the argument jsonb expression.
kwargs may be ignored by this operator but are required for API
conformance.
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
"""Boolean expression. Test if keys are a proper subset of the
keys of the argument jsonb expression.
"""
return self.operate(
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
comparator_factory = Comparator
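# A minimal rendering sketch (illustrative, not part of the upstream
# module): shows one of the PostgreSQL-specific operators defined above;
# the table and column names are assumptions for the example only.
def _example_jsonb_operators():  # pragma: no cover
    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.dialects import postgresql

    t = Table(
        "data_table",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("data", JSONB),
    )
    expr = t.c.data["some key"].astext == "some value"
    # renders roughly: data_table.data ->> %(data_1)s = %(param_1)s
    return expr.compile(dialect=postgresql.dialect())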

View File

@@ -0,0 +1,594 @@
# postgresql/pg8000.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: postgresql+pg8000
:name: pg8000
:dbapi: pg8000
:connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
:url: https://pypi.org/project/pg8000/
.. versionchanged:: 1.4 The pg8000 dialect has been updated for version
1.16.6 and higher, and is again part of SQLAlchemy's continuous integration
with full feature support.
.. _pg8000_unicode:
Unicode
-------
pg8000 will encode / decode string values between it and the server using the
PostgreSQL ``client_encoding`` parameter; by default this is the value in
the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
Typically, this can be changed to ``utf-8``, as a more useful default::
#client_encoding = sql_ascii # actually, defaults to database
# encoding
client_encoding = utf8
The ``client_encoding`` can be overridden for a session by executing the SQL:
SET CLIENT_ENCODING TO 'utf8';
SQLAlchemy will execute this SQL on all new connections based on the value
passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter::
engine = create_engine(
"postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
.. _pg8000_ssl:
SSL Connections
---------------
pg8000 accepts a Python ``SSLContext`` object which may be specified using the
:paramref:`_sa.create_engine.connect_args` dictionary::
import ssl
ssl_context = ssl.create_default_context()
engine = sa.create_engine(
"postgresql+pg8000://scott:tiger@192.168.0.199/test",
connect_args={"ssl_context": ssl_context},
)
If the server uses an automatically-generated certificate that is self-signed
or does not match the host name (as seen from the client), it may also be
necessary to disable hostname checking::
import ssl
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
engine = sa.create_engine(
"postgresql+pg8000://scott:tiger@192.168.0.199/test",
connect_args={"ssl_context": ssl_context},
)
.. _pg8000_isolation_level:
pg8000 Transaction Isolation Level
-------------------------------------
The pg8000 dialect offers the same isolation level settings as that
of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
* ``READ COMMITTED``
* ``READ UNCOMMITTED``
* ``REPEATABLE READ``
* ``SERIALIZABLE``
* ``AUTOCOMMIT``
.. seealso::
:ref:`postgresql_isolation_level`
:ref:`psycopg2_isolation_level`
""" # noqa
import decimal
import re
from uuid import UUID as _python_UUID
from .array import ARRAY as PGARRAY
from .base import _ColonCast
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
from .base import ENUM
from .base import INTERVAL
from .base import PGCompiler
from .base import PGDialect
from .base import PGExecutionContext
from .base import PGIdentifierPreparer
from .base import UUID
from .json import JSON
from .json import JSONB
from .json import JSONPathType
from ... import exc
from ... import processors
from ... import types as sqltypes
from ... import util
from ...sql.elements import quoted_name
class _PGNumeric(sqltypes.Numeric):
def result_processor(self, dialect, coltype):
if self.asdecimal:
if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
decimal.Decimal, self._effective_decimal_return_scale
)
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# pg8000 returns Decimal natively for 1700
return None
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype
)
else:
if coltype in _FLOAT_TYPES:
# pg8000 returns float natively for 701
return None
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
return processors.to_float
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype
)
class _PGNumericNoBind(_PGNumeric):
def bind_processor(self, dialect):
return None
class _PGJSON(JSON):
def result_processor(self, dialect, coltype):
return None
def get_dbapi_type(self, dbapi):
return dbapi.JSON
class _PGJSONB(JSONB):
def result_processor(self, dialect, coltype):
return None
def get_dbapi_type(self, dbapi):
return dbapi.JSONB
class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
def get_dbapi_type(self, dbapi):
raise NotImplementedError("should not be here")
class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
def get_dbapi_type(self, dbapi):
return dbapi.INTEGER
class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
def get_dbapi_type(self, dbapi):
return dbapi.STRING
class _PGJSONPathType(JSONPathType):
def get_dbapi_type(self, dbapi):
return 1009
class _PGUUID(UUID):
def bind_processor(self, dialect):
if not self.as_uuid:
def process(value):
if value is not None:
value = _python_UUID(value)
return value
return process
def result_processor(self, dialect, coltype):
if not self.as_uuid:
def process(value):
if value is not None:
value = str(value)
return value
return process
class _PGEnum(ENUM):
def get_dbapi_type(self, dbapi):
return dbapi.UNKNOWN
class _PGInterval(INTERVAL):
def get_dbapi_type(self, dbapi):
return dbapi.INTERVAL
@classmethod
def adapt_emulated_to_native(cls, interval, **kw):
return _PGInterval(precision=interval.second_precision)
class _PGTimeStamp(sqltypes.DateTime):
def get_dbapi_type(self, dbapi):
if self.timezone:
# TIMESTAMPTZOID
return 1184
else:
# TIMESTAMPOID
return 1114
class _PGTime(sqltypes.Time):
def get_dbapi_type(self, dbapi):
return dbapi.TIME
class _PGInteger(sqltypes.Integer):
def get_dbapi_type(self, dbapi):
return dbapi.INTEGER
class _PGSmallInteger(sqltypes.SmallInteger):
def get_dbapi_type(self, dbapi):
return dbapi.INTEGER
class _PGNullType(sqltypes.NullType):
def get_dbapi_type(self, dbapi):
return dbapi.NULLTYPE
class _PGBigInteger(sqltypes.BigInteger):
def get_dbapi_type(self, dbapi):
return dbapi.BIGINTEGER
class _PGBoolean(sqltypes.Boolean):
def get_dbapi_type(self, dbapi):
return dbapi.BOOLEAN
class _PGARRAY(PGARRAY):
def bind_expression(self, bindvalue):
return _ColonCast(bindvalue, self)
_server_side_id = util.counter()
class PGExecutionContext_pg8000(PGExecutionContext):
def create_server_side_cursor(self):
ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
return ServerSideCursor(self._dbapi_connection.cursor(), ident)
def pre_exec(self):
if not self.compiled:
return
class ServerSideCursor:
server_side = True
def __init__(self, cursor, ident):
self.ident = ident
self.cursor = cursor
@property
def connection(self):
return self.cursor.connection
@property
def rowcount(self):
return self.cursor.rowcount
@property
def description(self):
return self.cursor.description
def execute(self, operation, args=(), stream=None):
op = "DECLARE " + self.ident + " NO SCROLL CURSOR FOR " + operation
self.cursor.execute(op, args, stream=stream)
return self
def executemany(self, operation, param_sets):
self.cursor.executemany(operation, param_sets)
return self
def fetchone(self):
self.cursor.execute("FETCH FORWARD 1 FROM " + self.ident)
return self.cursor.fetchone()
def fetchmany(self, num=None):
if num is None:
return self.fetchall()
else:
self.cursor.execute(
"FETCH FORWARD " + str(int(num)) + " FROM " + self.ident
)
return self.cursor.fetchall()
def fetchall(self):
self.cursor.execute("FETCH FORWARD ALL FROM " + self.ident)
return self.cursor.fetchall()
def close(self):
self.cursor.execute("CLOSE " + self.ident)
self.cursor.close()
def setinputsizes(self, *sizes):
self.cursor.setinputsizes(*sizes)
def setoutputsize(self, size, column=None):
pass
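# A minimal streaming sketch (illustrative, not part of the upstream
# module): SQLAlchemy selects the ServerSideCursor above when the
# generic ``stream_results`` execution option is set; the URL and table
# name are assumptions for the example only.
def _example_stream_results():  # pragma: no cover
    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql+pg8000://scott:tiger@localhost/test")
    with engine.connect() as conn:
        result = conn.execution_options(stream_results=True).execute(
            text("SELECT * FROM some_large_table")
        )
        for row in result:
            pass  # rows arrive in FETCH FORWARD batches, not one fetchall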
class PGCompiler_pg8000(PGCompiler):
def visit_mod_binary(self, binary, operator, **kw):
return (
self.process(binary.left, **kw)
+ " %% "
+ self.process(binary.right, **kw)
)
class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
def __init__(self, *args, **kwargs):
PGIdentifierPreparer.__init__(self, *args, **kwargs)
self._double_percents = False
class PGDialect_pg8000(PGDialect):
driver = "pg8000"
supports_statement_cache = True
supports_unicode_statements = True
supports_unicode_binds = True
default_paramstyle = "format"
supports_sane_multi_rowcount = True
execution_ctx_cls = PGExecutionContext_pg8000
statement_compiler = PGCompiler_pg8000
preparer = PGIdentifierPreparer_pg8000
supports_server_side_cursors = True
use_setinputsizes = True
# reversed as of pg8000 1.16.6. 1.16.5 and lower
# are no longer compatible
description_encoding = None
# description_encoding = "use_encoding"
colspecs = util.update_copy(
PGDialect.colspecs,
{
sqltypes.Numeric: _PGNumericNoBind,
sqltypes.Float: _PGNumeric,
sqltypes.JSON: _PGJSON,
sqltypes.Boolean: _PGBoolean,
sqltypes.NullType: _PGNullType,
JSONB: _PGJSONB,
sqltypes.JSON.JSONPathType: _PGJSONPathType,
sqltypes.JSON.JSONIndexType: _PGJSONIndexType,
sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
UUID: _PGUUID,
sqltypes.Interval: _PGInterval,
INTERVAL: _PGInterval,
sqltypes.DateTime: _PGTimeStamp,
sqltypes.Time: _PGTime,
sqltypes.Integer: _PGInteger,
sqltypes.SmallInteger: _PGSmallInteger,
sqltypes.BigInteger: _PGBigInteger,
sqltypes.Enum: _PGEnum,
sqltypes.ARRAY: _PGARRAY,
},
)
def __init__(self, client_encoding=None, **kwargs):
PGDialect.__init__(self, **kwargs)
self.client_encoding = client_encoding
if self._dbapi_version < (1, 16, 6):
raise NotImplementedError("pg8000 1.16.6 or greater is required")
@util.memoized_property
def _dbapi_version(self):
if self.dbapi and hasattr(self.dbapi, "__version__"):
return tuple(
[
int(x)
for x in re.findall(
r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
)
]
)
else:
return (99, 99, 99)
@classmethod
def dbapi(cls):
return __import__("pg8000")
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
if "port" in opts:
opts["port"] = int(opts["port"])
opts.update(url.query)
return ([], opts)
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.InterfaceError) and "network error" in str(
e
):
# new as of pg8000 1.19.0 for broken connections
return True
# connection was closed normally
return "connection is closed" in str(e)
def set_isolation_level(self, connection, level):
level = level.replace("_", " ")
# adjust for ConnectionFairy possibly being present
if hasattr(connection, "dbapi_connection"):
connection = connection.dbapi_connection
if level == "AUTOCOMMIT":
connection.autocommit = True
elif level in self._isolation_lookup:
connection.autocommit = False
cursor = connection.cursor()
cursor.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION "
"ISOLATION LEVEL %s" % level
)
cursor.execute("COMMIT")
cursor.close()
else:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
"Valid isolation levels for %s are %s or AUTOCOMMIT"
% (level, self.name, ", ".join(self._isolation_lookup))
)
def set_readonly(self, connection, value):
cursor = connection.cursor()
try:
cursor.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION %s"
% ("READ ONLY" if value else "READ WRITE")
)
cursor.execute("COMMIT")
finally:
cursor.close()
def get_readonly(self, connection):
cursor = connection.cursor()
try:
cursor.execute("show transaction_read_only")
val = cursor.fetchone()[0]
finally:
cursor.close()
return val == "on"
def set_deferrable(self, connection, value):
cursor = connection.cursor()
try:
cursor.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION %s"
% ("DEFERRABLE" if value else "NOT DEFERRABLE")
)
cursor.execute("COMMIT")
finally:
cursor.close()
def get_deferrable(self, connection):
cursor = connection.cursor()
try:
cursor.execute("show transaction_deferrable")
val = cursor.fetchone()[0]
finally:
cursor.close()
return val == "on"
def set_client_encoding(self, connection, client_encoding):
# adjust for ConnectionFairy possibly being present
if hasattr(connection, "dbapi_connection"):
connection = connection.dbapi_connection
cursor = connection.cursor()
cursor.execute("SET CLIENT_ENCODING TO '" + client_encoding + "'")
cursor.execute("COMMIT")
cursor.close()
def do_set_input_sizes(self, cursor, list_of_tuples, context):
if self.positional:
cursor.setinputsizes(
*[dbtype for key, dbtype, sqltype in list_of_tuples]
)
else:
cursor.setinputsizes(
**{
key: dbtype
for key, dbtype, sqltype in list_of_tuples
if dbtype
}
)
def do_begin_twophase(self, connection, xid):
connection.connection.tpc_begin((0, xid, ""))
def do_prepare_twophase(self, connection, xid):
connection.connection.tpc_prepare()
def do_rollback_twophase(
self, connection, xid, is_prepared=True, recover=False
):
connection.connection.tpc_rollback((0, xid, ""))
def do_commit_twophase(
self, connection, xid, is_prepared=True, recover=False
):
connection.connection.tpc_commit((0, xid, ""))
def do_recover_twophase(self, connection):
return [row[1] for row in connection.connection.tpc_recover()]
def on_connect(self):
fns = []
def on_connect(conn):
conn.py_types[quoted_name] = conn.py_types[util.text_type]
fns.append(on_connect)
if self.client_encoding is not None:
def on_connect(conn):
self.set_client_encoding(conn, self.client_encoding)
fns.append(on_connect)
if self.isolation_level is not None:
def on_connect(conn):
self.set_isolation_level(conn, self.isolation_level)
fns.append(on_connect)
if self._json_deserializer:
def on_connect(conn):
# json
conn.register_in_adapter(114, self._json_deserializer)
# jsonb
conn.register_in_adapter(3802, self._json_deserializer)
fns.append(on_connect)
if len(fns) > 0:
def on_connect(conn):
for fn in fns:
fn(conn)
return on_connect
else:
return None
dialect = PGDialect_pg8000

View File

@@ -0,0 +1,124 @@
import time
from ... import exc
from ... import inspect
from ... import text
from ...testing import warn_test_suite
from ...testing.provision import create_db
from ...testing.provision import drop_all_schema_objects_post_tables
from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import log
from ...testing.provision import prepare_for_drop_tables
from ...testing.provision import set_default_schema_on_connection
from ...testing.provision import temp_table_keyword_args
@create_db.for_db("postgresql")
def _pg_create_db(cfg, eng, ident):
template_db = cfg.options.postgresql_templatedb
with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn:
if not template_db:
template_db = conn.exec_driver_sql(
"select current_database()"
).scalar()
attempt = 0
while True:
try:
conn.exec_driver_sql(
"CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)
)
except exc.OperationalError as err:
attempt += 1
if attempt >= 3:
raise
if "accessed by other users" in str(err):
log.info(
"Waiting to create %s, URI %r, "
"template DB %s is in use sleeping for .5",
ident,
eng.url,
template_db,
)
time.sleep(0.5)
except:
raise
else:
break
@drop_db.for_db("postgresql")
def _pg_drop_db(cfg, eng, ident):
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
with conn.begin():
conn.execute(
text(
"select pg_terminate_backend(pid) from pg_stat_activity "
"where usename=current_user and pid != pg_backend_pid() "
"and datname=:dname"
),
dict(dname=ident),
)
conn.exec_driver_sql("DROP DATABASE %s" % ident)
@temp_table_keyword_args.for_db("postgresql")
def _postgresql_temp_table_keyword_args(cfg, eng):
return {"prefixes": ["TEMPORARY"]}
@set_default_schema_on_connection.for_db("postgresql")
def _postgresql_set_default_schema_on_connection(
cfg, dbapi_connection, schema_name
):
existing_autocommit = dbapi_connection.autocommit
dbapi_connection.autocommit = True
cursor = dbapi_connection.cursor()
cursor.execute("SET SESSION search_path='%s'" % schema_name)
cursor.close()
dbapi_connection.autocommit = existing_autocommit
@drop_all_schema_objects_pre_tables.for_db("postgresql")
def drop_all_schema_objects_pre_tables(cfg, eng):
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
for xid in conn.execute("select gid from pg_prepared_xacts").scalars():
conn.execute("ROLLBACK PREPARED '%s'" % xid)
@drop_all_schema_objects_post_tables.for_db("postgresql")
def drop_all_schema_objects_post_tables(cfg, eng):
from sqlalchemy.dialects import postgresql
inspector = inspect(eng)
with eng.begin() as conn:
for enum in inspector.get_enums("*"):
conn.execute(
postgresql.DropEnumType(
postgresql.ENUM(name=enum["name"], schema=enum["schema"])
)
)
@prepare_for_drop_tables.for_db("postgresql")
def prepare_for_drop_tables(config, connection):
"""Ensure there are no locks on the current username/database."""
result = connection.exec_driver_sql(
"select pid, state, wait_event_type, query "
# "select pg_terminate_backend(pid), state, wait_event_type "
"from pg_stat_activity where "
"usename=current_user "
"and datname=current_database() and state='idle in transaction' "
"and pid != pg_backend_pid()"
)
rows = result.all() # noqa
if rows:
warn_test_suite(
"PostgreSQL may not be able to DROP tables due to "
"idle in transaction: %s"
% ("; ".join(row._mapping["query"] for row in rows))
)

File diff suppressed because it is too large

View File

@@ -0,0 +1,60 @@
# postgresql/psycopg2cffi.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: postgresql+psycopg2cffi
:name: psycopg2cffi
:dbapi: psycopg2cffi
:connectstring: postgresql+psycopg2cffi://user:password@host:port/dbname[?key=value&key=value...]
:url: https://pypi.org/project/psycopg2cffi/
``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C
layer. This makes it suitable for use in e.g. PyPy. Documentation
is as per ``psycopg2``.
.. versionadded:: 1.0.0
.. seealso::
:mod:`sqlalchemy.dialects.postgresql.psycopg2`
""" # noqa
from .psycopg2 import PGDialect_psycopg2
class PGDialect_psycopg2cffi(PGDialect_psycopg2):
driver = "psycopg2cffi"
supports_unicode_statements = True
supports_statement_cache = True
# psycopg2cffi's first release is 2.5.0, but reports
# __version__ as 2.4.4. Subsequent releases seem to have
# fixed this.
FEATURE_VERSION_MAP = dict(
native_json=(2, 4, 4),
native_jsonb=(2, 7, 1),
sane_multi_rowcount=(2, 4, 4),
array_oid=(2, 4, 4),
hstore_adapter=(2, 4, 4),
)
@classmethod
def dbapi(cls):
return __import__("psycopg2cffi")
@classmethod
def _psycopg2_extensions(cls):
root = __import__("psycopg2cffi", fromlist=["extensions"])
return root.extensions
@classmethod
def _psycopg2_extras(cls):
root = __import__("psycopg2cffi", fromlist=["extras"])
return root.extras
dialect = PGDialect_psycopg2cffi

View File

@@ -0,0 +1,278 @@
# postgresql/pygresql.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: postgresql+pygresql
:name: pygresql
:dbapi: pgdb
:connectstring: postgresql+pygresql://user:password@host:port/dbname[?key=value&key=value...]
:url: https://www.pygresql.org/
.. note::
The pygresql dialect is **not tested as part of SQLAlchemy's continuous
integration** and may have unresolved issues. The recommended PostgreSQL
dialect is psycopg2.
.. deprecated:: 1.4 The pygresql DBAPI is deprecated and will be removed
in a future version. Please use one of the supported DBAPIs to
connect to PostgreSQL.
""" # noqa
import decimal
import re
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
from .base import PGCompiler
from .base import PGDialect
from .base import PGIdentifierPreparer
from .base import UUID
from .hstore import HSTORE
from .json import JSON
from .json import JSONB
from ... import exc
from ... import processors
from ... import util
from ...sql.elements import Null
from ...types import JSON as Json
from ...types import Numeric
class _PGNumeric(Numeric):
def bind_processor(self, dialect):
return None
def result_processor(self, dialect, coltype):
if not isinstance(coltype, int):
coltype = coltype.oid
if self.asdecimal:
if coltype in _FLOAT_TYPES:
return processors.to_decimal_processor_factory(
decimal.Decimal, self._effective_decimal_return_scale
)
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
# PyGreSQL returns Decimal natively for 1700 (numeric)
return None
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype
)
else:
if coltype in _FLOAT_TYPES:
# PyGreSQL returns float natively for 701 (float8)
return None
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
return processors.to_float
else:
raise exc.InvalidRequestError(
"Unknown PG numeric type: %d" % coltype
)
class _PGHStore(HSTORE):
def bind_processor(self, dialect):
if not dialect.has_native_hstore:
return super(_PGHStore, self).bind_processor(dialect)
hstore = dialect.dbapi.Hstore
def process(value):
if isinstance(value, dict):
return hstore(value)
return value
return process
def result_processor(self, dialect, coltype):
if not dialect.has_native_hstore:
return super(_PGHStore, self).result_processor(dialect, coltype)
class _PGJSON(JSON):
def bind_processor(self, dialect):
if not dialect.has_native_json:
return super(_PGJSON, self).bind_processor(dialect)
json = dialect.dbapi.Json
def process(value):
if value is self.NULL:
value = None
elif isinstance(value, Null) or (
value is None and self.none_as_null
):
return None
if value is None or isinstance(value, (dict, list)):
return json(value)
return value
return process
def result_processor(self, dialect, coltype):
if not dialect.has_native_json:
return super(_PGJSON, self).result_processor(dialect, coltype)
class _PGJSONB(JSONB):
def bind_processor(self, dialect):
if not dialect.has_native_json:
return super(_PGJSONB, self).bind_processor(dialect)
json = dialect.dbapi.Json
def process(value):
if value is self.NULL:
value = None
elif isinstance(value, Null) or (
value is None and self.none_as_null
):
return None
if value is None or isinstance(value, (dict, list)):
return json(value)
return value
return process
def result_processor(self, dialect, coltype):
if not dialect.has_native_json:
return super(_PGJSONB, self).result_processor(dialect, coltype)
class _PGUUID(UUID):
def bind_processor(self, dialect):
if not dialect.has_native_uuid:
return super(_PGUUID, self).bind_processor(dialect)
uuid = dialect.dbapi.Uuid
def process(value):
if value is None:
return None
if isinstance(value, (str, bytes)):
if len(value) == 16:
return uuid(bytes=value)
return uuid(value)
if isinstance(value, int):
return uuid(int=value)
return value
return process
def result_processor(self, dialect, coltype):
if not dialect.has_native_uuid:
return super(_PGUUID, self).result_processor(dialect, coltype)
if not self.as_uuid:
def process(value):
if value is not None:
return str(value)
return process
class _PGCompiler(PGCompiler):
def visit_mod_binary(self, binary, operator, **kw):
return (
self.process(binary.left, **kw)
+ " %% "
+ self.process(binary.right, **kw)
)
def post_process_text(self, text):
return text.replace("%", "%%")
class _PGIdentifierPreparer(PGIdentifierPreparer):
def _escape_identifier(self, value):
value = value.replace(self.escape_quote, self.escape_to_quote)
return value.replace("%", "%%")
class PGDialect_pygresql(PGDialect):
driver = "pygresql"
supports_statement_cache = True
statement_compiler = _PGCompiler
preparer = _PGIdentifierPreparer
@classmethod
def dbapi(cls):
import pgdb
util.warn_deprecated(
"The pygresql DBAPI is deprecated and will be removed "
"in a future version. Please use one of the supported DBAPIs to "
"connect to PostgreSQL.",
version="1.4",
)
return pgdb
colspecs = util.update_copy(
PGDialect.colspecs,
{
Numeric: _PGNumeric,
HSTORE: _PGHStore,
Json: _PGJSON,
JSON: _PGJSON,
JSONB: _PGJSONB,
UUID: _PGUUID,
},
)
def __init__(self, **kwargs):
super(PGDialect_pygresql, self).__init__(**kwargs)
try:
version = self.dbapi.version
m = re.match(r"(\d+)\.(\d+)", version)
version = (int(m.group(1)), int(m.group(2)))
except (AttributeError, ValueError, TypeError):
version = (0, 0)
self.dbapi_version = version
if version < (5, 0):
has_native_hstore = has_native_json = has_native_uuid = False
if version != (0, 0):
util.warn(
"PyGreSQL is only fully supported by SQLAlchemy"
" since version 5.0."
)
else:
self.supports_unicode_statements = True
self.supports_unicode_binds = True
has_native_hstore = has_native_json = has_native_uuid = True
self.has_native_hstore = has_native_hstore
self.has_native_json = has_native_json
self.has_native_uuid = has_native_uuid
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
if "port" in opts:
opts["host"] = "%s:%s" % (
opts.get("host", "").rsplit(":", 1)[0],
opts.pop("port"),
)
opts.update(url.query)
return [], opts
def is_disconnect(self, e, connection, cursor):
if isinstance(e, self.dbapi.Error):
if not connection:
return False
try:
connection = connection.connection
except AttributeError:
pass
else:
if not connection:
return False
try:
return connection.closed
except AttributeError: # PyGreSQL < 5.0
return connection._cnx is None
return False
dialect = PGDialect_pygresql

View File

@@ -0,0 +1,126 @@
# postgresql/pypostgresql.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: postgresql+pypostgresql
:name: py-postgresql
:dbapi: pypostgresql
:connectstring: postgresql+pypostgresql://user:password@host:port/dbname[?key=value&key=value...]
:url: https://python.projects.pgfoundry.org/
.. note::
The pypostgresql dialect is **not tested as part of SQLAlchemy's continuous
integration** and may have unresolved issues. The recommended PostgreSQL
driver is psycopg2.
.. deprecated:: 1.4 The py-postgresql DBAPI is deprecated and will be removed
in a future version. This DBAPI is superseded by the external
version available at external-dialect_. Please use the external version or
one of the supported DBAPIs to connect to PostgreSQL.
.. TODO update link
.. _external-dialect: https://github.com/PyGreSQL
""" # noqa
from .base import PGDialect
from .base import PGExecutionContext
from ... import processors
from ... import types as sqltypes
from ... import util
class PGNumeric(sqltypes.Numeric):
def bind_processor(self, dialect):
return processors.to_str
def result_processor(self, dialect, coltype):
if self.asdecimal:
return None
else:
return processors.to_float
class PGExecutionContext_pypostgresql(PGExecutionContext):
pass
class PGDialect_pypostgresql(PGDialect):
driver = "pypostgresql"
supports_statement_cache = True
supports_unicode_statements = True
supports_unicode_binds = True
description_encoding = None
default_paramstyle = "pyformat"
# requires trunk version to support sane rowcounts
# TODO: use dbapi version information to set this flag appropriately
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_pypostgresql
colspecs = util.update_copy(
PGDialect.colspecs,
{
sqltypes.Numeric: PGNumeric,
# prevents PGNumeric from being used
sqltypes.Float: sqltypes.Float,
},
)
@classmethod
def dbapi(cls):
from postgresql.driver import dbapi20
# TODO update link
util.warn_deprecated(
"The py-postgresql DBAPI is deprecated and will be removed "
"in a future version. This DBAPI is superseded by the external"
"version available at https://github.com/PyGreSQL. Please "
"use one of the supported DBAPIs to connect to PostgreSQL.",
version="1.4",
)
return dbapi20
_DBAPI_ERROR_NAMES = [
"Error",
"InterfaceError",
"DatabaseError",
"DataError",
"OperationalError",
"IntegrityError",
"InternalError",
"ProgrammingError",
"NotSupportedError",
]
@util.memoized_property
def dbapi_exception_translation_map(self):
if self.dbapi is None:
return {}
return dict(
(getattr(self.dbapi, name).__name__, name)
for name in self._DBAPI_ERROR_NAMES
)
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user")
if "port" in opts:
opts["port"] = int(opts["port"])
else:
opts["port"] = 5432
opts.update(url.query)
return ([], opts)
def is_disconnect(self, e, connection, cursor):
return "connection is closed" in str(e)
dialect = PGDialect_pypostgresql

View File

@@ -0,0 +1,138 @@
# Copyright (C) 2013-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from ... import types as sqltypes
__all__ = ("INT4RANGE", "INT8RANGE", "NUMRANGE")
class RangeOperators(object):
"""
This mixin provides functionality for the Range Operators
listed in the Range Operators table of the `PostgreSQL documentation`__
for Range Functions and Operators. It is used by all the range types
provided in the ``postgresql`` dialect and can likely be used for
any range types you create yourself.
__ https://www.postgresql.org/docs/current/static/functions-range.html
No extra support is provided for the Range Functions listed in the Range
Functions table of the PostgreSQL documentation. For these, the normal
:func:`~sqlalchemy.sql.expression.func` object should be used.
"""
class comparator_factory(sqltypes.Concatenable.Comparator):
"""Define comparison operations for range types."""
def __ne__(self, other):
"Boolean expression. Returns true if two ranges are not equal"
if other is None:
return super(RangeOperators.comparator_factory, self).__ne__(
other
)
else:
return self.expr.op("<>", is_comparison=True)(other)
def contains(self, other, **kw):
"""Boolean expression. Returns true if the right hand operand,
which can be an element or a range, is contained within the
column.
kwargs may be ignored by this operator but are required for API
conformance.
"""
return self.expr.op("@>", is_comparison=True)(other)
def contained_by(self, other):
"""Boolean expression. Returns true if the column is contained
within the right hand operand.
"""
return self.expr.op("<@", is_comparison=True)(other)
def overlaps(self, other):
"""Boolean expression. Returns true if the column overlaps
(has points in common with) the right hand operand.
"""
return self.expr.op("&&", is_comparison=True)(other)
def strictly_left_of(self, other):
"""Boolean expression. Returns true if the column is strictly
left of the right hand operand.
"""
return self.expr.op("<<", is_comparison=True)(other)
__lshift__ = strictly_left_of
def strictly_right_of(self, other):
"""Boolean expression. Returns true if the column is strictly
right of the right hand operand.
"""
return self.expr.op(">>", is_comparison=True)(other)
__rshift__ = strictly_right_of
def not_extend_right_of(self, other):
"""Boolean expression. Returns true if the range in the column
does not extend right of the range in the operand.
"""
return self.expr.op("&<", is_comparison=True)(other)
def not_extend_left_of(self, other):
"""Boolean expression. Returns true if the range in the column
does not extend left of the range in the operand.
"""
return self.expr.op("&>", is_comparison=True)(other)
def adjacent_to(self, other):
"""Boolean expression. Returns true if the range in the column
is adjacent to the range in the operand.
"""
return self.expr.op("-|-", is_comparison=True)(other)
def __add__(self, other):
"""Range expression. Returns the union of the two ranges.
Will raise an exception if the resulting range is not
contiguous.
"""
return self.expr.op("+")(other)
class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL INT4RANGE type."""
__visit_name__ = "INT4RANGE"
class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL INT8RANGE type."""
__visit_name__ = "INT8RANGE"
class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL NUMRANGE type."""
__visit_name__ = "NUMRANGE"
class DATERANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL DATERANGE type."""
__visit_name__ = "DATERANGE"
class TSRANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL TSRANGE type."""
__visit_name__ = "TSRANGE"
class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
"""Represent the PostgreSQL TSTZRANGE type."""
__visit_name__ = "TSTZRANGE"

View File

@@ -0,0 +1,58 @@
# sqlite/__init__.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from . import base # noqa
from . import pysqlcipher # noqa
from . import pysqlite # noqa
from .base import BLOB
from .base import BOOLEAN
from .base import CHAR
from .base import DATE
from .base import DATETIME
from .base import DECIMAL
from .base import FLOAT
from .base import INTEGER
from .base import JSON
from .base import NUMERIC
from .base import REAL
from .base import SMALLINT
from .base import TEXT
from .base import TIME
from .base import TIMESTAMP
from .base import VARCHAR
from .dml import Insert
from .dml import insert
from ...util import compat
if compat.py3k:
from . import aiosqlite # noqa
# default dialect
base.dialect = dialect = pysqlite.dialect
__all__ = (
"BLOB",
"BOOLEAN",
"CHAR",
"DATE",
"DATETIME",
"DECIMAL",
"FLOAT",
"INTEGER",
"JSON",
"NUMERIC",
"SMALLINT",
"TEXT",
"TIME",
"TIMESTAMP",
"VARCHAR",
"REAL",
"Insert",
"insert",
"dialect",
)

View File

@@ -0,0 +1,335 @@
# sqlite/aiosqlite.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: sqlite+aiosqlite
:name: aiosqlite
:dbapi: aiosqlite
:connectstring: sqlite+aiosqlite:///file_path
:url: https://pypi.org/project/aiosqlite/
The aiosqlite dialect provides support for the SQLAlchemy asyncio interface
running on top of pysqlite.
aiosqlite is a wrapper around pysqlite that uses a background thread for
each connection. It does not actually use non-blocking IO, as SQLite
databases are not socket-based. However, it does provide a working asyncio
interface that's useful for testing and prototyping purposes.
Using a special asyncio mediation layer, the aiosqlite dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.
This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine("sqlite+aiosqlite:///filename")
The URL passes through all arguments to the ``pysqlite`` driver, so all
connection arguments are the same as those of :ref:`pysqlite`.
""" # noqa
from .base import SQLiteExecutionContext
from .pysqlite import SQLiteDialect_pysqlite
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only
class AsyncAdapt_aiosqlite_cursor:
__slots__ = (
"_adapt_connection",
"_connection",
"description",
"await_",
"_rows",
"arraysize",
"rowcount",
"lastrowid",
)
server_side = False
def __init__(self, adapt_connection):
self._adapt_connection = adapt_connection
self._connection = adapt_connection._connection
self.await_ = adapt_connection.await_
self.arraysize = 1
self.rowcount = -1
self.description = None
self._rows = []
def close(self):
self._rows[:] = []
def execute(self, operation, parameters=None):
try:
_cursor = self.await_(self._connection.cursor())
if parameters is None:
self.await_(_cursor.execute(operation))
else:
self.await_(_cursor.execute(operation, parameters))
if _cursor.description:
self.description = _cursor.description
self.lastrowid = self.rowcount = -1
if not self.server_side:
self._rows = self.await_(_cursor.fetchall())
else:
self.description = None
self.lastrowid = _cursor.lastrowid
self.rowcount = _cursor.rowcount
if not self.server_side:
self.await_(_cursor.close())
else:
self._cursor = _cursor
except Exception as error:
self._adapt_connection._handle_exception(error)
def executemany(self, operation, seq_of_parameters):
try:
_cursor = self.await_(self._connection.cursor())
self.await_(_cursor.executemany(operation, seq_of_parameters))
self.description = None
self.lastrowid = _cursor.lastrowid
self.rowcount = _cursor.rowcount
self.await_(_cursor.close())
except Exception as error:
self._adapt_connection._handle_exception(error)
def setinputsizes(self, *inputsizes):
pass
def __iter__(self):
while self._rows:
yield self._rows.pop(0)
def fetchone(self):
if self._rows:
return self._rows.pop(0)
else:
return None
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
retval = self._rows[0:size]
self._rows[:] = self._rows[size:]
return retval
def fetchall(self):
retval = self._rows[:]
self._rows[:] = []
return retval
class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor):
__slots__ = "_cursor"
server_side = True
def __init__(self, *arg, **kw):
super().__init__(*arg, **kw)
self._cursor = None
def close(self):
if self._cursor is not None:
self.await_(self._cursor.close())
self._cursor = None
def fetchone(self):
return self.await_(self._cursor.fetchone())
def fetchmany(self, size=None):
if size is None:
size = self.arraysize
return self.await_(self._cursor.fetchmany(size=size))
def fetchall(self):
return self.await_(self._cursor.fetchall())
class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
await_ = staticmethod(await_only)
__slots__ = ("dbapi", "_connection")
def __init__(self, dbapi, connection):
self.dbapi = dbapi
self._connection = connection
@property
def isolation_level(self):
return self._connection.isolation_level
@isolation_level.setter
def isolation_level(self, value):
try:
self._connection.isolation_level = value
except Exception as error:
self._handle_exception(error)
def create_function(self, *args, **kw):
try:
self.await_(self._connection.create_function(*args, **kw))
except Exception as error:
self._handle_exception(error)
def cursor(self, server_side=False):
if server_side:
return AsyncAdapt_aiosqlite_ss_cursor(self)
else:
return AsyncAdapt_aiosqlite_cursor(self)
def execute(self, *args, **kw):
return self.await_(self._connection.execute(*args, **kw))
def rollback(self):
try:
self.await_(self._connection.rollback())
except Exception as error:
self._handle_exception(error)
def commit(self):
try:
self.await_(self._connection.commit())
except Exception as error:
self._handle_exception(error)
def close(self):
try:
self.await_(self._connection.close())
except Exception as error:
self._handle_exception(error)
def _handle_exception(self, error):
if (
isinstance(error, ValueError)
and error.args[0] == "no active connection"
):
util.raise_(
self.dbapi.sqlite.OperationalError("no active connection"),
from_=error,
)
else:
raise error
class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection):
__slots__ = ()
await_ = staticmethod(await_fallback)
class AsyncAdapt_aiosqlite_dbapi:
def __init__(self, aiosqlite, sqlite):
self.aiosqlite = aiosqlite
self.sqlite = sqlite
self.paramstyle = "qmark"
self._init_dbapi_attributes()
def _init_dbapi_attributes(self):
for name in (
"DatabaseError",
"Error",
"IntegrityError",
"NotSupportedError",
"OperationalError",
"ProgrammingError",
"sqlite_version",
"sqlite_version_info",
):
setattr(self, name, getattr(self.aiosqlite, name))
for name in ("PARSE_COLNAMES", "PARSE_DECLTYPES"):
setattr(self, name, getattr(self.sqlite, name))
for name in ("Binary",):
setattr(self, name, getattr(self.sqlite, name))
def connect(self, *arg, **kw):
async_fallback = kw.pop("async_fallback", False)
# Q. WHY do we need this?
# A. Because there is no way to set connection.isolation_level
# otherwise
# Q. BUT HOW do you know it is SAFE ?????
# A. The only operation that isn't safe is the isolation level set
# operation which aiosqlite appears to have let slip through even
# though pysqlite appears to do check_same_thread for this.
# All execute operations etc. should be safe because they all
# go through the single executor thread.
kw["check_same_thread"] = False
connection = self.aiosqlite.connect(*arg, **kw)
# it's a Thread. you'll thank us later
connection.daemon = True
if util.asbool(async_fallback):
return AsyncAdaptFallback_aiosqlite_connection(
self,
await_fallback(connection),
)
else:
return AsyncAdapt_aiosqlite_connection(
self,
await_only(connection),
)
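# Illustrative sketch (not part of the dialect): with async_fallback=True
# the adapted DBAPI can be driven synchronously, e.g.
#
#     dbapi = AsyncAdapt_aiosqlite_dbapi(
#         __import__("aiosqlite"), __import__("sqlite3")
#     )
#     conn = dbapi.connect(":memory:", async_fallback=True)
#     cursor = conn.cursor()
#     cursor.execute("select 1")
#     print(cursor.fetchall())
#     conn.close()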
class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(server_side=True)
class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
driver = "aiosqlite"
supports_statement_cache = True
is_async = True
supports_server_side_cursors = True
execution_ctx_cls = SQLiteExecutionContext_aiosqlite
@classmethod
def dbapi(cls):
return AsyncAdapt_aiosqlite_dbapi(
__import__("aiosqlite"), __import__("sqlite3")
)
@classmethod
def get_pool_class(cls, url):
if cls._is_url_file_db(url):
return pool.NullPool
else:
return pool.StaticPool
def is_disconnect(self, e, connection, cursor):
if isinstance(
e, self.dbapi.OperationalError
) and "no active connection" in str(e):
return True
return super().is_disconnect(e, connection, cursor)
def get_driver_connection(self, connection):
return connection._connection
dialect = SQLiteDialect_aiosqlite

File diff suppressed because it is too large


@@ -0,0 +1,200 @@
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from ... import util
from ...sql import coercions
from ...sql import roles
from ...sql.base import _exclusive_against
from ...sql.base import _generative
from ...sql.base import ColumnCollection
from ...sql.dml import Insert as StandardInsert
from ...sql.elements import ClauseElement
from ...sql.expression import alias
from ...util.langhelpers import public_factory
__all__ = ("Insert", "insert")
class Insert(StandardInsert):
"""SQLite-specific implementation of INSERT.
Adds methods for SQLite-specific syntaxes such as ON CONFLICT.
The :class:`_sqlite.Insert` object is created using the
:func:`sqlalchemy.dialects.sqlite.insert` function.
.. versionadded:: 1.4
.. seealso::
:ref:`sqlite_on_conflict_insert`
"""
stringify_dialect = "sqlite"
inherit_cache = False
@util.memoized_property
def excluded(self):
"""Provide the ``excluded`` namespace for an ON CONFLICT statement
SQLite's ON CONFLICT clause allows reference to the row that would
be inserted, known as ``excluded``. This attribute provides
all columns in this row to be referenceable.
.. tip:: The :attr:`_sqlite.Insert.excluded` attribute is an instance
of :class:`_expression.ColumnCollection`, which provides an
interface the same as that of the :attr:`_schema.Table.c`
collection described at :ref:`metadata_tables_and_columns`.
With this collection, ordinary names are accessible like attributes
(e.g. ``stmt.excluded.some_column``), but special names and
dictionary method names should be accessed using indexed access,
such as ``stmt.excluded["column name"]`` or
``stmt.excluded["values"]``. See the docstring for
:class:`_expression.ColumnCollection` for further examples.
"""
return alias(self.table, name="excluded").columns
_on_conflict_exclusive = _exclusive_against(
"_post_values_clause",
msgs={
"_post_values_clause": "This Insert construct already has "
"an ON CONFLICT clause established"
},
)
@_generative
@_on_conflict_exclusive
def on_conflict_do_update(
self,
index_elements=None,
index_where=None,
set_=None,
where=None,
):
r"""
Specifies a DO UPDATE SET action for ON CONFLICT clause.
:param index_elements:
A sequence consisting of string column names, :class:`_schema.Column`
objects, or other column expression objects that will be used
to infer a target index or unique constraint.
:param index_where:
Additional WHERE criterion that can be used to infer a
conditional target index.
:param set\_:
A dictionary or other mapping object
where the keys are either names of columns in the target table,
or :class:`_schema.Column` objects or other ORM-mapped columns
matching that of the target table, and expressions or literals
as values, specifying the ``SET`` actions to take.
.. versionadded:: 1.4 The
:paramref:`_sqlite.Insert.on_conflict_do_update.set_`
parameter supports :class:`_schema.Column` objects from the target
:class:`_schema.Table` as keys.
.. warning:: This dictionary does **not** take into account
Python-specified default UPDATE values or generation functions,
e.g. those specified using :paramref:`_schema.Column.onupdate`.
These values will not be exercised for an ON CONFLICT style of
UPDATE, unless they are manually specified in the
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.
:param where:
Optional argument. If present, can be a literal SQL
string or an acceptable expression for a ``WHERE`` clause
that restricts the rows affected by ``DO UPDATE SET``. Rows
not meeting the ``WHERE`` condition will not be updated
(effectively a ``DO NOTHING`` for those rows).
"""
self._post_values_clause = OnConflictDoUpdate(
index_elements, index_where, set_, where
)
@_generative
@_on_conflict_exclusive
def on_conflict_do_nothing(self, index_elements=None, index_where=None):
"""
Specifies a DO NOTHING action for ON CONFLICT clause.
:param index_elements:
A sequence consisting of string column names, :class:`_schema.Column`
objects, or other column expression objects that will be used
to infer a target index or unique constraint.
:param index_where:
Additional WHERE criterion that can be used to infer a
conditional target index.
"""
self._post_values_clause = OnConflictDoNothing(
index_elements, index_where
)
insert = public_factory(
Insert, ".dialects.sqlite.insert", ".dialects.sqlite.Insert"
)
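# Illustrative usage sketch (``my_table`` is a hypothetical Table with a
# unique column ``id``):
#
#     from sqlalchemy.dialects.sqlite import insert
#
#     stmt = insert(my_table).values(id=1, data="inserted value")
#     stmt = stmt.on_conflict_do_update(
#         index_elements=[my_table.c.id],
#         set_=dict(data=stmt.excluded.data),
#     )
#     conn.execute(stmt)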
class OnConflictClause(ClauseElement):
stringify_dialect = "sqlite"
def __init__(self, index_elements=None, index_where=None):
if index_elements is not None:
self.constraint_target = None
self.inferred_target_elements = index_elements
self.inferred_target_whereclause = index_where
else:
self.constraint_target = (
self.inferred_target_elements
) = self.inferred_target_whereclause = None
class OnConflictDoNothing(OnConflictClause):
__visit_name__ = "on_conflict_do_nothing"
class OnConflictDoUpdate(OnConflictClause):
__visit_name__ = "on_conflict_do_update"
def __init__(
self,
index_elements=None,
index_where=None,
set_=None,
where=None,
):
super(OnConflictDoUpdate, self).__init__(
index_elements=index_elements,
index_where=index_where,
)
if isinstance(set_, dict):
if not set_:
raise ValueError("set parameter dictionary must not be empty")
elif isinstance(set_, ColumnCollection):
set_ = dict(set_)
else:
raise ValueError(
"set parameter must be a non-empty dictionary "
"or a ColumnCollection such as the `.c.` collection "
"of a Table object"
)
self.update_values_to_set = [
(coercions.expect(roles.DMLColumnRole, key), value)
for key, value in set_.items()
]
self.update_whereclause = where


@@ -0,0 +1,84 @@
from ... import types as sqltypes
class JSON(sqltypes.JSON):
"""SQLite JSON type.
SQLite supports JSON as of version 3.9 through its JSON1_ extension. Note
that JSON1_ is a
`loadable extension <https://www.sqlite.org/loadext.html>`_ and as such
may not be available, or may require run-time loading.
:class:`_sqlite.JSON` is used automatically whenever the base
:class:`_types.JSON` datatype is used against a SQLite backend.
.. seealso::
:class:`_types.JSON` - main documentation for the generic
cross-platform JSON datatype.
The :class:`_sqlite.JSON` type supports persistence of JSON values
as well as the core index operations provided by :class:`_types.JSON`
datatype, by adapting the operations to render the ``JSON_EXTRACT``
function wrapped in the ``JSON_QUOTE`` function at the database level.
Extracted values are quoted in order to ensure that the results are
always JSON string values.
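    As a brief sketch (``data_table`` here stands for a hypothetical table
    with a ``JSON`` column named ``data``), an indexed element is selected
    as::

        from sqlalchemy import select

        stmt = select(data_table.c.data["some_key"])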
.. versionadded:: 1.3
.. _JSON1: https://www.sqlite.org/json1.html
"""
# Note: these objects currently match exactly those of MySQL, however since
# these are not generalizable to all JSON implementations, remain separately
# implemented for each dialect.
class _FormatTypeMixin(object):
def _format_value(self, value):
raise NotImplementedError()
def bind_processor(self, dialect):
super_proc = self.string_bind_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
def literal_processor(self, dialect):
super_proc = self.string_literal_processor(dialect)
def process(value):
value = self._format_value(value)
if super_proc:
value = super_proc(value)
return value
return process
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
def _format_value(self, value):
if isinstance(value, int):
value = "$[%s]" % value
else:
value = '$."%s"' % value
return value
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
def _format_value(self, value):
return "$%s" % (
"".join(
[
"[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
for elem in value
]
)
)
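# Illustrative note (not part of the module): JSONPathType renders a path
# such as ["key", 1] as the SQLite JSON path string '$."key"[1]'.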


@@ -0,0 +1,142 @@
import os
import re
from ... import exc
from ...engine import url as sa_url
from ...testing.provision import create_db
from ...testing.provision import drop_db
from ...testing.provision import follower_url_from_main
from ...testing.provision import generate_driver_url
from ...testing.provision import log
from ...testing.provision import post_configure_engine
from ...testing.provision import run_reap_dbs
from ...testing.provision import stop_test_class_outside_fixtures
from ...testing.provision import temp_table_keyword_args
# TODO: I can't get this to build dynamically with pytest-xdist procs
_drivernames = {"pysqlite", "aiosqlite", "pysqlcipher"}
@generate_driver_url.for_db("sqlite")
def generate_driver_url(url, driver, query_str):
if driver == "pysqlcipher" and url.get_driver_name() != "pysqlcipher":
if url.database:
url = url.set(database=url.database + ".enc")
url = url.set(password="test")
url = url.set(drivername="sqlite+%s" % (driver,))
try:
url.get_dialect()
except exc.NoSuchModuleError:
return None
else:
return url
@follower_url_from_main.for_db("sqlite")
def _sqlite_follower_url_from_main(url, ident):
url = sa_url.make_url(url)
if not url.database or url.database == ":memory:":
return url
else:
m = re.match(r"(.+?)\.(.+)$", url.database)
name, ext = m.group(1, 2)
drivername = url.get_driver_name()
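        # e.g. database "test.db" with ident "xyz" under the pysqlite
        # driver becomes "sqlite+pysqlite:///pysqlite_xyz.db"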
return sa_url.make_url(
"sqlite+%s:///%s_%s.%s" % (drivername, drivername, ident, ext)
)
@post_configure_engine.for_db("sqlite")
def _sqlite_post_configure_engine(url, engine, follower_ident):
from sqlalchemy import event
@event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
# use file DBs in all cases, memory acts kind of strangely
# as an attached
if not follower_ident:
# note this test_schema.db gets created for all test runs.
# there's not any dedicated cleanup step for it. it in some
# ways corresponds to the "test.test_schema" schema that's
# expected to be already present, so for now it just stays
# in a given checkout directory.
dbapi_connection.execute(
'ATTACH DATABASE "%s_test_schema.db" AS test_schema'
% (engine.driver,)
)
else:
dbapi_connection.execute(
'ATTACH DATABASE "%s_%s_test_schema.db" AS test_schema'
% (follower_ident, engine.driver)
)
@create_db.for_db("sqlite")
def _sqlite_create_db(cfg, eng, ident):
pass
@drop_db.for_db("sqlite")
def _sqlite_drop_db(cfg, eng, ident):
for path in [
"%s.db" % ident,
"%s_%s_test_schema.db" % (ident, eng.driver),
]:
if os.path.exists(path):
log.info("deleting SQLite database file: %s" % path)
os.remove(path)
@stop_test_class_outside_fixtures.for_db("sqlite")
def stop_test_class_outside_fixtures(config, db, cls):
with db.connect() as conn:
files = [
row.file
for row in conn.exec_driver_sql("PRAGMA database_list")
if row.file
]
if files:
db.dispose()
# some sqlite file tests are not cleaning up well yet, so do this
# just to make things simple for now
for file_ in files:
if file_ and os.path.exists(file_):
os.remove(file_)
@temp_table_keyword_args.for_db("sqlite")
def _sqlite_temp_table_keyword_args(cfg, eng):
return {"prefixes": ["TEMPORARY"]}
@run_reap_dbs.for_db("sqlite")
def _reap_sqlite_dbs(url, idents):
log.info("db reaper connecting to %r", url)
log.info("identifiers in file: %s", ", ".join(idents))
for ident in idents:
# we don't have a config so we can't call _sqlite_drop_db due to the
# decorator
for ext in ("db", "db.enc"):
for path in (
["%s.%s" % (ident, ext)]
+ [
"%s_%s.%s" % (drivername, ident, ext)
for drivername in _drivernames
]
+ [
"%s_test_schema.%s" % (drivername, ext)
for drivername in _drivernames
]
+ [
"%s_%s_test_schema.%s" % (ident, drivername, ext)
for drivername in _drivernames
]
):
if os.path.exists(path):
log.info("deleting SQLite database file: %s" % path)
os.remove(path)


@@ -0,0 +1,164 @@
# sqlite/pysqlcipher.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: sqlite+pysqlcipher
:name: pysqlcipher
:dbapi: sqlcipher 3 or pysqlcipher
:connectstring: sqlite+pysqlcipher://:passphrase@/file_path[?kdf_iter=<iter>]
Dialect for support of DBAPIs that make use of the
`SQLCipher <https://www.zetetic.net/sqlcipher>`_ backend.
Driver
------
Current dialect selection logic is:
* If the :paramref:`_sa.create_engine.module` parameter supplies a DBAPI module,
that module is used.
* Otherwise for Python 3, choose https://pypi.org/project/sqlcipher3/
* If not available, fall back to https://pypi.org/project/pysqlcipher3/
* For Python 2, https://pypi.org/project/pysqlcipher/ is used.
.. warning:: The ``pysqlcipher3`` and ``pysqlcipher`` DBAPI drivers are no
longer maintained; the ``sqlcipher3`` driver as of this writing appears
to be current. For future compatibility, any pysqlcipher-compatible DBAPI
may be used as follows::
import sqlcipher_compatible_driver
from sqlalchemy import create_engine
e = create_engine(
"sqlite+pysqlcipher://:password@/dbname.db",
module=sqlcipher_compatible_driver
)
These drivers make use of the SQLCipher engine. This system essentially
introduces new PRAGMA commands to SQLite which allow the setting of a
passphrase and other encryption parameters, allowing the database file to be
encrypted.
Connect Strings
---------------
The format of the connect string is in every way the same as that
of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the
"password" field is now accepted, which should contain a passphrase::
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db')
For an absolute file path, two leading slashes should be used for the
database name::
e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db')
A selection of additional encryption-related pragmas supported by SQLCipher
as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed
in the query string, and will result in that PRAGMA being called for each
new connection. Currently, ``cipher``, ``kdf_iter``,
``cipher_page_size`` and ``cipher_use_hmac`` are supported::
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000')
.. warning:: Previous versions of SQLAlchemy did not take the
   encryption-related pragmas passed in the URL string into consideration;
   they were silently ignored. This may cause errors when opening files
   saved by a previous SQLAlchemy version if the encryption options do not
   match.
Pooling Behavior
----------------
The driver makes a change to the default pool behavior of pysqlite
as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver
has been observed to be significantly slower on connection than the
pysqlite driver, most likely due to the encryption overhead, so the
dialect here defaults to using the :class:`.SingletonThreadPool`
implementation,
instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool
implementation is entirely configurable using the
:paramref:`_sa.create_engine.poolclass` parameter; the :class:`.StaticPool` may
be more feasible for single-threaded use, or :class:`.NullPool` may be used
to prevent unencrypted connections from being held open for long periods of
time, at the expense of slower startup time for new connections.
""" # noqa
from __future__ import absolute_import
from .pysqlite import SQLiteDialect_pysqlite
from ... import pool
from ... import util
class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
driver = "pysqlcipher"
supports_statement_cache = True
pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac")
@classmethod
def dbapi(cls):
if util.py3k:
try:
import sqlcipher3 as sqlcipher
except ImportError:
pass
else:
return sqlcipher
from pysqlcipher3 import dbapi2 as sqlcipher
else:
from pysqlcipher import dbapi2 as sqlcipher
return sqlcipher
@classmethod
def get_pool_class(cls, url):
return pool.SingletonThreadPool
def on_connect_url(self, url):
super_on_connect = super(
SQLiteDialect_pysqlcipher, self
).on_connect_url(url)
# pull the info we need from the URL early. Even though URL
# is immutable, we don't want any in-place changes to the URL
# to affect things
passphrase = url.password or ""
url_query = dict(url.query)
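        # e.g. sqlite+pysqlcipher://:secret@/foo.db?kdf_iter=64000 emits
        # 'pragma key="secret"' and then 'pragma kdf_iter="64000"' on each
        # new connection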
def on_connect(conn):
cursor = conn.cursor()
cursor.execute('pragma key="%s"' % passphrase)
for prag in self.pragmas:
value = url_query.get(prag, None)
if value is not None:
cursor.execute('pragma %s="%s"' % (prag, value))
cursor.close()
if super_on_connect:
super_on_connect(conn)
return on_connect
def create_connect_args(self, url):
plain_url = url._replace(password=None)
plain_url = plain_url.difference_update_query(self.pragmas)
return super(SQLiteDialect_pysqlcipher, self).create_connect_args(
plain_url
)
dialect = SQLiteDialect_pysqlcipher


@@ -0,0 +1,613 @@
# sqlite/pysqlite.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
r"""
.. dialect:: sqlite+pysqlite
:name: pysqlite
:dbapi: sqlite3
:connectstring: sqlite+pysqlite:///file_path
:url: https://docs.python.org/library/sqlite3.html
Note that ``pysqlite`` is the same driver as the ``sqlite3``
module included with the Python distribution.
Driver
------
The ``sqlite3`` Python DBAPI is standard on all modern Python versions;
for cPython and Pypy, no additional installation is necessary.
Connect Strings
---------------
The file specification for the SQLite database is taken as the "database"
portion of the URL. Note that the format of a SQLAlchemy url is::
driver://user:pass@host/database
This means that the actual filename to be used starts with the characters to
the **right** of the third slash. So connecting to a relative filepath
looks like::
# relative path
e = create_engine('sqlite:///path/to/database.db')
An absolute path, which is denoted by starting with a slash, means you
need **four** slashes::
# absolute path
e = create_engine('sqlite:////path/to/database.db')
To use a Windows path, regular drive specifications and backslashes can be
used. Double backslashes are probably needed::
# absolute path on Windows
e = create_engine('sqlite:///C:\\path\\to\\database.db')
The sqlite ``:memory:`` identifier is the default if no filepath is
present. Specify ``sqlite://`` and nothing else::
# in-memory database
e = create_engine('sqlite://')
.. _pysqlite_uri_connections:
URI Connections
^^^^^^^^^^^^^^^
Modern versions of SQLite support an alternative system of connecting using a
`driver level URI <https://www.sqlite.org/uri.html>`_, which has the advantage
that additional driver-level arguments can be passed including options such as
"read only". The Python sqlite3 driver supports this mode under modern Python
3 versions. The SQLAlchemy pysqlite driver supports this mode of use by
specifying "uri=true" in the URL query string. The SQLite-level "URI" is kept
as the "database" portion of the SQLAlchemy url (that is, following a slash)::
e = create_engine("sqlite:///file:path/to/database?mode=ro&uri=true")
.. note:: The "uri=true" parameter must appear in the **query string**
of the URL. It will not currently work as expected if it is only
present in the :paramref:`_sa.create_engine.connect_args`
parameter dictionary.
The logic reconciles the simultaneous presence of SQLAlchemy's query string and
SQLite's query string by separating out the parameters that belong to the
Python sqlite3 driver vs. those that belong to the SQLite URI. This is
achieved through the use of a fixed list of parameters known to be accepted by
the Python side of the driver. For example, to include a URL that indicates
the Python sqlite3 "timeout" and "check_same_thread" parameters, along with the
SQLite "mode" and "nolock" parameters, they can all be passed together on the
query string::
e = create_engine(
"sqlite:///file:path/to/database?"
"check_same_thread=true&timeout=10&mode=ro&nolock=1&uri=true"
)
Above, the pysqlite / sqlite3 DBAPI would be passed arguments as::
sqlite3.connect(
"file:path/to/database?mode=ro&nolock=1",
check_same_thread=True, timeout=10, uri=True
)
Regarding future parameters added to either the Python or native drivers, new
parameter names added to the SQLite URI scheme should be automatically
accommodated by this scheme. New parameter names added to the Python driver
side can be accommodated by specifying them in the
:paramref:`_sa.create_engine.connect_args` dictionary,
until dialect support is
added by SQLAlchemy. For the less likely case that the native SQLite driver
adds a new parameter name that overlaps with one of the existing, known Python
driver parameters (such as "timeout" perhaps), SQLAlchemy's dialect would
require adjustment for the URL scheme to continue to support this.
As is always the case for all SQLAlchemy dialects, the entire "URL" process
can be bypassed in :func:`_sa.create_engine` through the use of the
:paramref:`_sa.create_engine.creator`
parameter which allows for a custom callable
that creates a Python sqlite3 driver level connection directly.
.. versionadded:: 1.3.9
.. seealso::
`Uniform Resource Identifiers <https://www.sqlite.org/uri.html>`_ - in
the SQLite documentation
.. _pysqlite_regexp:
Regular Expression Support
---------------------------
.. versionadded:: 1.4
Support for the :meth:`_sql.ColumnOperators.regexp_match` operator is provided
using Python's re.search_ function. SQLite itself does not include a working
regular expression operator; instead, it includes a non-implemented placeholder
operator ``REGEXP`` that calls a user-defined function that must be provided.
SQLAlchemy's implementation makes use of the pysqlite create_function_ hook
as follows::
def regexp(a, b):
return re.search(a, b) is not None
sqlite_connection.create_function(
"regexp", 2, regexp,
)
There is currently no support for regular expression flags as a separate
argument, as these are not supported by SQLite's REGEXP operator; however,
flags may be included inline within the regular expression string. See
`Python regular expressions`_ for
details.
.. seealso::
`Python regular expressions`_: Documentation for Python's regular expression syntax.
.. _create_function: https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function
.. _re.search: https://docs.python.org/3/library/re.html#re.search
.. _Python regular expressions: https://docs.python.org/3/library/re.html#re.search
Compatibility with sqlite3 "native" date and datetime types
-----------------------------------------------------------
The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and
sqlite3.PARSE_COLNAMES options, which have the effect that any column
or expression explicitly cast as "date" or "timestamp" will be converted
to a Python date or datetime object. The date and datetime types provided
with the pysqlite dialect are not currently compatible with these options,
since they render the ISO date/datetime including microseconds, which
pysqlite's driver does not. Additionally, SQLAlchemy does not at
this time automatically render the "cast" syntax required for the
freestanding functions "current_timestamp" and "current_date" to return
datetime/date types natively. Unfortunately, pysqlite
does not provide the standard DBAPI types in ``cursor.description``,
leaving SQLAlchemy with no way to detect these types on the fly
without expensive per-row type checks.
Keeping in mind that pysqlite's parsing option is not recommended,
nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES
can be forced if one configures "native_datetime=True" on create_engine()::
engine = create_engine('sqlite://',
connect_args={'detect_types':
sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES},
native_datetime=True
)
With this flag enabled, the DATE and TIMESTAMP types (but note - not the
DATETIME or TIME types... confused yet?) will not perform any bind parameter
or result processing. Execution of "func.current_date()" will return a string.
"func.current_timestamp()" is registered as returning a DATETIME type in
SQLAlchemy, so this function still receives SQLAlchemy-level result
processing.
.. _pysqlite_threading_pooling:
Threading/Pooling Behavior
---------------------------
Pysqlite's default behavior is to prohibit the usage of a single connection
in more than one thread. This is originally intended to work with older
versions of SQLite that did not support multithreaded operation under
various circumstances. In particular, older SQLite versions
did not allow a ``:memory:`` database to be used in multiple threads
under any circumstances.
Pysqlite does include a now-undocumented flag known as
``check_same_thread`` which will disable this check, however note that
pysqlite connections are still not safe to use concurrently in multiple
threads. In particular, any statement execution calls would need to be
externally mutexed, as Pysqlite does not provide for thread-safe propagation
of error messages among other things. So while even ``:memory:`` databases
can be shared among threads in modern SQLite, Pysqlite doesn't provide enough
thread-safety to make this usage worth it.
SQLAlchemy sets up pooling to work with Pysqlite's default behavior:
* When a ``:memory:`` SQLite database is specified, the dialect by default
will use :class:`.SingletonThreadPool`. This pool maintains a single
connection per thread, so that all access to the engine within the current
thread use the same ``:memory:`` database - other threads would access a
different ``:memory:`` database.
* When a file-based database is specified, the dialect will use
:class:`.NullPool` as the source of connections. This pool closes and
discards connections which are returned to the pool immediately. SQLite
file-based connections have extremely low overhead, so pooling is not
necessary. The scheme also prevents a connection from being used again in
a different thread and works best with SQLite's coarse-grained file locking.
Using a Memory Database in Multiple Threads
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
To use a ``:memory:`` database in a multithreaded scenario, the same
connection object must be shared among threads, since the database exists
only within the scope of that connection. The
:class:`.StaticPool` implementation will maintain a single connection
globally, and the ``check_same_thread`` flag can be passed to Pysqlite
as ``False``::
from sqlalchemy.pool import StaticPool
engine = create_engine('sqlite://',
connect_args={'check_same_thread':False},
poolclass=StaticPool)
Note that using a ``:memory:`` database in multiple threads requires a recent
version of SQLite.
Using Temporary Tables with SQLite
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Due to the way SQLite deals with temporary tables, if you wish to use a
temporary table in a file-based SQLite database across multiple checkouts
from the connection pool, such as when using an ORM :class:`.Session` where
the temporary table should continue to remain after :meth:`.Session.commit` or
:meth:`.Session.rollback` is called, a pool which maintains a single
connection must be used. Use :class:`.SingletonThreadPool` if the scope is
only needed within the current thread, or :class:`.StaticPool` if scope is
needed within multiple threads for this case::
# maintain the same connection per thread
from sqlalchemy.pool import SingletonThreadPool
engine = create_engine('sqlite:///mydb.db',
poolclass=SingletonThreadPool)
# maintain the same connection across all threads
from sqlalchemy.pool import StaticPool
engine = create_engine('sqlite:///mydb.db',
poolclass=StaticPool)
Note that :class:`.SingletonThreadPool` should be configured for the number
of threads that are to be used; beyond that number, connections will be
closed out in a non-deterministic way.
Unicode
-------
The pysqlite driver only returns Python ``unicode`` objects in result sets,
never plain strings, and accommodates ``unicode`` objects within bound
parameter values in all cases. Regardless of the SQLAlchemy string type in
use, string-based result values will be Python ``unicode`` in Python 2.
The :class:`.Unicode` type should still be used to indicate those columns that
require unicode, however, so that non-``unicode`` values passed inadvertently
will emit a warning. Pysqlite will emit an error if a non-``unicode`` string
is passed containing non-ASCII characters.
Dealing with Mixed String / Binary Columns in Python 3
------------------------------------------------------
The SQLite database is weakly typed, and as such it is possible when using
binary values, which in Python 3 are represented as ``b'some string'``, that a
particular SQLite database can have data values within different rows where
some of them will be returned as a ``b''`` value by the Pysqlite driver, and
others will be returned as Python strings, e.g. ``''`` values. This situation
is not known to occur if the SQLAlchemy :class:`.LargeBinary` datatype is used
consistently, however if a particular SQLite database has data that was
inserted using the Pysqlite driver directly, or when using the SQLAlchemy
:class:`.String` type which was later changed to :class:`.LargeBinary`, the
table will not be consistently readable because SQLAlchemy's
:class:`.LargeBinary` datatype does not handle strings so it has no way of
"encoding" a value that is in string format.
To deal with a SQLite table that has mixed string / binary data in the
same column, use a custom type that will check each row individually::
# note this is Python 3 only
from sqlalchemy import String
from sqlalchemy import TypeDecorator
class MixedBinary(TypeDecorator):
impl = String
cache_ok = True
def process_result_value(self, value, dialect):
if isinstance(value, str):
value = bytes(value, 'utf-8')
elif value is not None:
value = bytes(value)
return value
Then use the above ``MixedBinary`` datatype in the place where
:class:`.LargeBinary` would normally be used.
.. _pysqlite_serializable:
Serializable isolation / Savepoints / Transactional DDL
-------------------------------------------------------
In the section :ref:`sqlite_concurrency`, we refer to the pysqlite
driver's assortment of issues that prevent several features of SQLite
from working correctly. The pysqlite DBAPI driver has several
long-standing bugs which impact the correctness of its transactional
behavior. In its default mode of operation, SQLite features such as
SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are
non-functional, and in order to use these features, workarounds must
be taken.
The issue is essentially that the driver attempts to second-guess the user's
intent, failing to start transactions and sometimes ending them prematurely, in
an effort to minimize the SQLite database's file locking behavior, even
though SQLite itself uses "shared" locks for read-only activities.
SQLAlchemy chooses to not alter this behavior by default, as it is the
long-expected behavior of the pysqlite driver; if and when the pysqlite
driver repairs these issues, that will be a motivation to change
SQLAlchemy's defaults as well.
The good news is that with a few events, we can implement transactional
support fully, by disabling pysqlite's feature entirely and emitting BEGIN
ourselves. This is achieved using two event listeners::
from sqlalchemy import create_engine, event
engine = create_engine("sqlite:///myfile.db")
@event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
# disable pysqlite's emitting of the BEGIN statement entirely.
# also stops it from emitting COMMIT before any DDL.
dbapi_connection.isolation_level = None
@event.listens_for(engine, "begin")
def do_begin(conn):
# emit our own BEGIN
conn.exec_driver_sql("BEGIN")
.. warning:: When using the above recipe, it is advised to not use the
:paramref:`.Connection.execution_options.isolation_level` setting on
:class:`_engine.Connection` and :func:`_sa.create_engine`
with the SQLite driver,
as this function necessarily will also alter the ".isolation_level" setting.
Above, we intercept a new pysqlite connection and disable any transactional
integration. Then, at the point at which SQLAlchemy knows that transaction
scope is to begin, we emit ``"BEGIN"`` ourselves.
When we take control of ``"BEGIN"``, we can also control directly SQLite's
locking modes, introduced at
`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_,
by adding the desired locking mode to our ``"BEGIN"``::
@event.listens_for(engine, "begin")
def do_begin(conn):
conn.exec_driver_sql("BEGIN EXCLUSIVE")
.. seealso::
`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_ -
on the SQLite site
`sqlite3 SELECT does not BEGIN a transaction <https://bugs.python.org/issue9924>`_ -
on the Python bug tracker
`sqlite3 module breaks transactions and potentially corrupts data <https://bugs.python.org/issue10740>`_ -
on the Python bug tracker
""" # noqa
import os
import re
from .base import DATE
from .base import DATETIME
from .base import SQLiteDialect
from ... import exc
from ... import pool
from ... import types as sqltypes
from ... import util
class _SQLite_pysqliteTimeStamp(DATETIME):
def bind_processor(self, dialect):
if dialect.native_datetime:
return None
else:
return DATETIME.bind_processor(self, dialect)
def result_processor(self, dialect, coltype):
if dialect.native_datetime:
return None
else:
return DATETIME.result_processor(self, dialect, coltype)
class _SQLite_pysqliteDate(DATE):
def bind_processor(self, dialect):
if dialect.native_datetime:
return None
else:
return DATE.bind_processor(self, dialect)
def result_processor(self, dialect, coltype):
if dialect.native_datetime:
return None
else:
return DATE.result_processor(self, dialect, coltype)
class SQLiteDialect_pysqlite(SQLiteDialect):
default_paramstyle = "qmark"
supports_statement_cache = True
colspecs = util.update_copy(
SQLiteDialect.colspecs,
{
sqltypes.Date: _SQLite_pysqliteDate,
sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp,
},
)
if not util.py2k:
description_encoding = None
driver = "pysqlite"
@classmethod
def dbapi(cls):
if util.py2k:
try:
from pysqlite2 import dbapi2 as sqlite
except ImportError:
try:
from sqlite3 import dbapi2 as sqlite
except ImportError as e:
raise e
else:
from sqlite3 import dbapi2 as sqlite
return sqlite
@classmethod
def _is_url_file_db(cls, url):
if (url.database and url.database != ":memory:") and (
url.query.get("mode", None) != "memory"
):
return True
else:
return False
@classmethod
def get_pool_class(cls, url):
if cls._is_url_file_db(url):
return pool.NullPool
else:
return pool.SingletonThreadPool
def _get_server_version_info(self, connection):
return self.dbapi.sqlite_version_info
_isolation_lookup = SQLiteDialect._isolation_lookup.union(
{
"AUTOCOMMIT": None,
}
)
def set_isolation_level(self, connection, level):
if hasattr(connection, "dbapi_connection"):
dbapi_connection = connection.dbapi_connection
else:
dbapi_connection = connection
if level == "AUTOCOMMIT":
dbapi_connection.isolation_level = None
else:
dbapi_connection.isolation_level = ""
return super(SQLiteDialect_pysqlite, self).set_isolation_level(
connection, level
)
def on_connect(self):
        super_connect = super(SQLiteDialect_pysqlite, self).on_connect()
def regexp(a, b):
if b is None:
return None
return re.search(a, b) is not None
def set_regexp(connection):
if hasattr(connection, "dbapi_connection"):
dbapi_connection = connection.dbapi_connection
else:
dbapi_connection = connection
dbapi_connection.create_function(
"regexp",
2,
regexp,
)
fns = [set_regexp]
if self.isolation_level is not None:
def iso_level(conn):
self.set_isolation_level(conn, self.isolation_level)
fns.append(iso_level)
        def connect(conn):
            # invoke the base dialect's hook, if any, then the local ones;
            # previously the base hook was retrieved but never called
            if super_connect:
                super_connect(conn)
            for fn in fns:
                fn(conn)

        return connect
def create_connect_args(self, url):
if url.username or url.password or url.host or url.port:
raise exc.ArgumentError(
"Invalid SQLite URL: %s\n"
"Valid SQLite URL forms are:\n"
" sqlite:///:memory: (or, sqlite://)\n"
" sqlite:///relative/path/to/file.db\n"
" sqlite:////absolute/path/to/file.db" % (url,)
)
# theoretically, this list can be augmented, at least as far as
# parameter names accepted by sqlite3/pysqlite, using
# inspect.getfullargspec(). for the moment this seems like overkill
# as these parameters don't change very often, and as always,
# parameters passed to connect_args will always go to the
# sqlite3/pysqlite driver.
pysqlite_args = [
("uri", bool),
("timeout", float),
("isolation_level", str),
("detect_types", int),
("check_same_thread", bool),
("cached_statements", int),
]
opts = url.query
pysqlite_opts = {}
for key, type_ in pysqlite_args:
util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts)
if pysqlite_opts.get("uri", False):
uri_opts = dict(opts)
# here, we are actually separating the parameters that go to
# sqlite3/pysqlite vs. those that go the SQLite URI. What if
# two names conflict? again, this seems to be not the case right
# now, and in the case that new names are added to
# either side which overlap, again the sqlite3/pysqlite parameters
# can be passed through connect_args instead of in the URL.
# If SQLite native URIs add a parameter like "timeout" that
# we already have listed here for the python driver, then we need
# to adjust for that here.
for key, type_ in pysqlite_args:
uri_opts.pop(key, None)
filename = url.database
if uri_opts:
# sorting of keys is for unit test support
filename += "?" + (
"&".join(
"%s=%s" % (key, uri_opts[key])
for key in sorted(uri_opts)
)
)
else:
filename = url.database or ":memory:"
if filename != ":memory:":
filename = os.path.abspath(filename)
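        # e.g. sqlite:///foo.db?timeout=10 produces
        # ([<absolute path to foo.db>], {"timeout": 10.0})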
return ([filename], pysqlite_opts)
def is_disconnect(self, e, connection, cursor):
return isinstance(
e, self.dbapi.ProgrammingError
) and "Cannot operate on a closed database." in str(e)
dialect = SQLiteDialect_pysqlite

View File

@@ -0,0 +1,67 @@
# sybase/__init__.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from . import base # noqa
from . import pyodbc # noqa
from . import pysybase # noqa
from .base import BIGINT
from .base import BINARY
from .base import BIT
from .base import CHAR
from .base import DATE
from .base import DATETIME
from .base import FLOAT
from .base import IMAGE
from .base import INT
from .base import INTEGER
from .base import MONEY
from .base import NCHAR
from .base import NUMERIC
from .base import NVARCHAR
from .base import SMALLINT
from .base import SMALLMONEY
from .base import TEXT
from .base import TIME
from .base import TINYINT
from .base import UNICHAR
from .base import UNITEXT
from .base import UNIVARCHAR
from .base import VARBINARY
from .base import VARCHAR
# default dialect
base.dialect = dialect = pyodbc.dialect
__all__ = (
"CHAR",
"VARCHAR",
"TIME",
"NCHAR",
"NVARCHAR",
"TEXT",
"DATE",
"DATETIME",
"FLOAT",
"NUMERIC",
"BIGINT",
"INT",
"INTEGER",
"SMALLINT",
"BINARY",
"VARBINARY",
"UNITEXT",
"UNICHAR",
"UNIVARCHAR",
"IMAGE",
"BIT",
"MONEY",
"SMALLMONEY",
"TINYINT",
"dialect",
)

File diff suppressed because it is too large


@@ -0,0 +1,34 @@
# sybase/mxodbc.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: sybase+mxodbc
:name: mxODBC
:dbapi: mxodbc
:connectstring: sybase+mxodbc://<username>:<password>@<dsnname>
:url: https://www.egenix.com/
.. note::
This dialect is a stub only and is likely non functional at this time.
"""
from sqlalchemy.connectors.mxodbc import MxODBCConnector
from sqlalchemy.dialects.sybase.base import SybaseDialect
from sqlalchemy.dialects.sybase.base import SybaseExecutionContext
class SybaseExecutionContext_mxodbc(SybaseExecutionContext):
pass
class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect):
execution_ctx_cls = SybaseExecutionContext_mxodbc
supports_statement_cache = True
dialect = SybaseDialect_mxodbc


@@ -0,0 +1,89 @@
# sybase/pyodbc.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: sybase+pyodbc
:name: PyODBC
:dbapi: pyodbc
:connectstring: sybase+pyodbc://<username>:<password>@<dsnname>[/<database>]
:url: https://pypi.org/project/pyodbc/
Unicode Support
---------------
The pyodbc driver currently supports usage of these Sybase types with
Unicode or multibyte strings::
CHAR
NCHAR
NVARCHAR
TEXT
VARCHAR
Currently *not* supported are::
UNICHAR
UNITEXT
UNIVARCHAR
""" # noqa
import decimal
from sqlalchemy import processors
from sqlalchemy import types as sqltypes
from sqlalchemy.connectors.pyodbc import PyODBCConnector
from sqlalchemy.dialects.sybase.base import SybaseDialect
from sqlalchemy.dialects.sybase.base import SybaseExecutionContext
class _SybNumeric_pyodbc(sqltypes.Numeric):
"""Turns Decimals with adjusted() < -6 into floats.
It's not yet known how to get decimals with many
significant digits or very large adjusted() into Sybase
via pyodbc.
"""
def bind_processor(self, dialect):
super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect)
def process(value):
if self.asdecimal and isinstance(value, decimal.Decimal):
if value.adjusted() < -6:
return processors.to_float(value)
if super_process:
return super_process(value)
else:
return value
return process
class SybaseExecutionContext_pyodbc(SybaseExecutionContext):
def set_ddl_autocommit(self, connection, value):
if value:
connection.autocommit = True
else:
connection.autocommit = False
class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect):
execution_ctx_cls = SybaseExecutionContext_pyodbc
supports_statement_cache = True
colspecs = {sqltypes.Numeric: _SybNumeric_pyodbc}
@classmethod
def dbapi(cls):
return PyODBCConnector.dbapi()
dialect = SybaseDialect_pyodbc


@@ -0,0 +1,106 @@
# sybase/pysybase.py
# Copyright (C) 2010-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""
.. dialect:: sybase+pysybase
:name: Python-Sybase
:dbapi: Sybase
:connectstring: sybase+pysybase://<username>:<password>@<dsn>/[database name]
:url: https://python-sybase.sourceforge.net/
Unicode Support
---------------
The python-sybase driver does not appear to support non-ASCII strings of any
kind at this time.
""" # noqa
from sqlalchemy import processors
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.sybase.base import SybaseDialect
from sqlalchemy.dialects.sybase.base import SybaseExecutionContext
from sqlalchemy.dialects.sybase.base import SybaseSQLCompiler
class _SybNumeric(sqltypes.Numeric):
def result_processor(self, dialect, type_):
if not self.asdecimal:
return processors.to_float
else:
return sqltypes.Numeric.result_processor(self, dialect, type_)
class SybaseExecutionContext_pysybase(SybaseExecutionContext):
def set_ddl_autocommit(self, dbapi_connection, value):
if value:
# call commit() on the Sybase connection directly,
# to avoid any side effects of calling a Connection
# transactional method inside of pre_exec()
dbapi_connection.commit()
def pre_exec(self):
SybaseExecutionContext.pre_exec(self)
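        # rewrite bind parameter names into the Sybase "@name" form,
        # e.g. {"x": 1} becomes {"@x": 1}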
for param in self.parameters:
for key in list(param):
param["@" + key] = param[key]
del param[key]
class SybaseSQLCompiler_pysybase(SybaseSQLCompiler):
def bindparam_string(self, name, **kw):
return "@" + name
class SybaseDialect_pysybase(SybaseDialect):
driver = "pysybase"
execution_ctx_cls = SybaseExecutionContext_pysybase
statement_compiler = SybaseSQLCompiler_pysybase
supports_statement_cache = True
colspecs = {sqltypes.Numeric: _SybNumeric, sqltypes.Float: sqltypes.Float}
@classmethod
def dbapi(cls):
import Sybase
return Sybase
def create_connect_args(self, url):
opts = url.translate_connect_args(username="user", password="passwd")
return ([opts.pop("host")], opts)
def do_executemany(self, cursor, statement, parameters, context=None):
# calling python-sybase executemany yields:
# TypeError: string too long for buffer
for param in parameters:
cursor.execute(statement, param)
def _get_server_version_info(self, connection):
vers = connection.exec_driver_sql("select @@version_number").scalar()
# i.e. 15500, 15000, 12500 == (15, 5, 0, 0), (15, 0, 0, 0),
        # (12, 5, 0, 0); use floor division so the components stay ints
        # under Python 3
        return (
            vers // 1000,
            vers % 1000 // 100,
            vers % 100 // 10,
            vers % 10,
        )
def is_disconnect(self, e, connection, cursor):
if isinstance(
e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError)
):
msg = str(e)
return (
"Unable to complete network request to host" in msg
or "Invalid connection state" in msg
or "Invalid cursor state" in msg
)
else:
return False
dialect = SybaseDialect_pysybase